コード例 #1
0
ファイル: model.py プロジェクト: XiongWei1012/HOPE-1
def select_model(model_def):
    """Instantiate the network named by ``model_def`` (case-insensitive).

    Args:
        model_def: Model name. After lower-casing it must be one of
            'hopenet', 'resnet10', 'resnet18', 'resnet50', 'resnet101',
            'graphunet' or 'graphnet'.

    Returns:
        The freshly constructed model object.

    Raises:
        NameError: If ``model_def`` does not name a supported model.
    """
    wanted = model_def.lower()

    # Each entry pairs a lazy constructor with the message printed once the
    # model has been built; lambdas keep unselected models from being touched.
    registry = {
        'hopenet': (lambda: HopeNet(), 'HopeNet is loaded'),
        'resnet10': (lambda: resnet10(pretrained=False, num_classes=29 * 2),
                     'ResNet10 is loaded'),
        'resnet18': (lambda: resnet18(pretrained=False, num_classes=29 * 2),
                     'ResNet18 is loaded'),
        'resnet50': (lambda: resnet50(pretrained=False, num_classes=29 * 2),
                     'ResNet50 is loaded'),
        'resnet101': (lambda: resnet101(pretrained=False, num_classes=29 * 2),
                      'ResNet101 is loaded'),
        'graphunet': (lambda: GraphUNet(in_features=2, out_features=3),
                      'GraphUNet is loaded'),
        'graphnet': (lambda: GraphNet(in_features=2, out_features=3),
                     'GraphNet is loaded'),
    }

    if wanted not in registry:
        raise NameError('Undefined model')

    build, loaded_message = registry[wanted]
    model = build()
    print(loaded_message)
    return model
コード例 #2
0
def select_model(model_def):
    """Instantiate the network named by ``model_def`` (case-insensitive).

    Args:
        model_def: Model name. After lower-casing it must be one of
            'hourglass', 'posehg', 'graphuhand', 'resnet10', 'resnet18',
            'resnet50', 'resnet101', 'graphunet' or 'graphnet'.

    Returns:
        The freshly constructed model object.

    Raises:
        NameError: If ``model_def`` does not name a supported model.
    """
    wanted = model_def.lower()

    # The ResNet variants are built with num_classes = 21 * 2.
    resnet_outputs = 21 * 2

    # name -> (lazy constructor, message printed after construction)
    registry = {
        'hourglass': (lambda: Net_HM_HG(21), 'HourGlass Net is loaded'),
        'posehg': (lambda: Net_Pose_HG(21), 'PoseHG Net is loaded'),
        'graphuhand': (lambda: GraphUHandNet(), 'GraphUHand Net is loaded'),
        'resnet10': (
            lambda: resnet10(pretrained=False, num_classes=resnet_outputs),
            'ResNet10 is loaded'),
        'resnet18': (
            lambda: resnet18(pretrained=False, num_classes=resnet_outputs),
            'ResNet18 is loaded'),
        'resnet50': (
            lambda: resnet50(pretrained=False, num_classes=resnet_outputs),
            'ResNet50 is loaded'),
        'resnet101': (
            lambda: resnet101(pretrained=False, num_classes=resnet_outputs),
            'ResNet101 is loaded'),
        'graphunet': (lambda: GraphUNet(in_features=2, out_features=3),
                      'GraphUNet is loaded'),
        'graphnet': (lambda: GraphNet(in_features=2, out_features=3),
                     'GraphNet is loaded'),
    }

    if wanted not in registry:
        raise NameError('Undefined model')

    build, loaded_message = registry[wanted]
    model = build()
    print(loaded_message)
    return model
コード例 #3
0
def generate_cammodel(config):
    """Build the ResNet selected by ``config.model_depth``.

    Args:
        config: Options object. The attributes read are ``model_depth``,
            ``n_classes``, ``resnet_shortcut``, ``sample_size``,
            ``sample_duration``, ``channels`` and ``no_cuda``.

    Returns:
        tuple: ``(model, model.parameters())``. The model has been moved to
        the GPU with ``.cuda()`` unless ``config.no_cuda`` is truthy.

    Raises:
        ValueError: If ``config.model_depth`` is not 10, 18, 34 or 50.
            (Previously such a depth left ``model`` unassigned and failed
            later with an UnboundLocalError.)
    """
    builder_names = {
        10: 'resnet10',
        18: 'resnet18',
        34: 'resnet34',
        50: 'resnet50',
    }
    if config.model_depth not in builder_names:
        raise ValueError(
            'Unsupported model_depth {!r}; expected one of 10, 18, 34, 50'
            .format(config.model_depth))

    # Every supported depth takes the same keyword arguments.
    builder = getattr(resnet, builder_names[config.model_depth])
    model = builder(num_classes=config.n_classes,
                    shortcut_type=config.resnet_shortcut,
                    sample_size=config.sample_size,
                    sample_duration=config.sample_duration,
                    channels=config.channels)

    if not config.no_cuda:
        model = model.cuda()
    return model, model.parameters()
コード例 #4
0
ファイル: model.py プロジェクト: xjsxujingsong/VGGSound
def Resnet(opt):
    """Build the ResNet variant selected by ``opt.model_depth``.

    Args:
        opt: Options object providing ``model_depth`` and ``n_classes``;
            ``pool`` is additionally read for depths 18, 34 and 50.

    Returns:
        The constructed model.

    Raises:
        AssertionError: If ``opt.model_depth`` is not one of
            10, 18, 34, 50, 101, 152, 200.
    """
    depth = opt.model_depth
    assert depth in [10, 18, 34, 50, 101, 152, 200]

    # Builders called with ``num_classes`` only.
    classes_only = {
        10: 'resnet10',
        101: 'resnet101',
        152: 'resnet152',
        200: 'resnet200',
    }
    # Builders that also receive the ``pool`` option.
    with_pool = {
        18: 'resnet18',
        34: 'resnet34',
        50: 'resnet50',
    }

    if depth in classes_only:
        model = getattr(resnet, classes_only[depth])(
            num_classes=opt.n_classes)
    elif depth in with_pool:
        model = getattr(resnet, with_pool[depth])(
            num_classes=opt.n_classes,
            pool=opt.pool)
    return model
コード例 #5
0
def generate_model(opt):
    """Build a pretrained (SE-)ResNet and return it with its trainable params.

    Args:
        opt: Options object. The attributes read are ``model``
            ('resnet' or 'se_resnet'), ``model_depth``, ``n_classes``,
            ``no_cuda`` and ``ft_begin_index``.

    Returns:
        tuple: ``(model, parameters)`` where ``parameters`` is the result of
        the selected family's ``get_fine_tuning_parameters(model,
        opt.ft_begin_index)``. Unless ``opt.no_cuda`` is truthy the model is
        moved to the GPU and wrapped in ``nn.DataParallel`` first.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` names a
            configuration this function cannot build.
    """
    # Only families with a construction branch below are accepted. 'densenet'
    # used to pass this check and then crash on an unassigned ``model``.
    assert opt.model in ['resnet', 'se_resnet']

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        builder = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = builder(pretrained=True, num_classes=opt.n_classes)
    elif opt.model == 'se_resnet':
        # Depth 152 used to pass this check without a matching branch, which
        # left ``model`` unassigned; it is rejected up front instead.
        assert opt.model_depth in [18, 34, 50, 101]

        from models.se_resnet import get_fine_tuning_parameters

        builder = getattr(se_resnet, 'resnet%d' % opt.model_depth)
        model = builder(pretrained=True, num_classes=opt.n_classes)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    # Both the CUDA and the CPU path select parameters the same way.
    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
コード例 #6
0
ファイル: model.py プロジェクト: PWB97/Deepfake-detection
def get_resnet_3d(num_classes=2,
                  model_depth=10,
                  shortcut_type='B',
                  sample_size=112,
                  sample_duration=16):
    """Build the ResNet variant of the requested depth.

    Args:
        num_classes: Forwarded to the builder as ``num_classes``.
        model_depth: One of 10, 18, 34, 50, 101, 152, 200.
        shortcut_type: Forwarded to the builder as ``shortcut_type``.
        sample_size: Forwarded to the builder as ``sample_size``.
        sample_duration: Forwarded to the builder as ``sample_duration``.

    Returns:
        The constructed model.

    Raises:
        AssertionError: If ``model_depth`` is not a supported depth.
    """
    assert model_depth in [10, 18, 34, 50, 101, 152, 200]

    explicit_builders = {
        10: 'resnet10',
        18: 'resnet18',
        34: 'resnet34',
        50: 'resnet50',
        101: 'resnet101',
        152: 'resnet152',
    }
    # Any depth not listed above falls through to resnet200.
    builder_name = explicit_builders.get(model_depth, 'resnet200')
    build = getattr(resnet, builder_name)

    model = build(num_classes=num_classes,
                  shortcut_type=shortcut_type,
                  sample_size=sample_size,
                  sample_duration=sample_duration)
    return model
コード例 #7
0
def generate_model(opt):
    """Build the ResNet or ResNeXt selected by ``opt``.

    Args:
        opt: Options object providing ``model`` ('resnet' or 'resnext'),
            ``model_depth`` and ``no_cuda``. The whole object is also passed
            to the chosen builder as ``opt=opt``.

    Returns:
        tuple: ``(model, model.parameters())``. The model has been moved to
        the GPU with ``.cuda()`` unless ``opt.no_cuda`` is truthy.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is not
            supported.
    """
    family = opt.model
    assert family in ['resnet', 'resnext']

    if family == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        # Imported here but not referenced again within this function.
        from models.resnet import get_fine_tuning_parameters

        builder_name = {
            10: 'resnet10',
            18: 'resnet18',
            34: 'resnet34',
            50: 'resnet50',
            101: 'resnet101',
            152: 'resnet152',
            200: 'resnet200',
        }.get(opt.model_depth)
        if builder_name is not None:
            model = getattr(resnet, builder_name)(opt=opt)
    elif family == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        # Imported here but not referenced again within this function.
        from models.resnext import get_fine_tuning_parameters

        builder_name = {
            50: 'resnet50',
            101: 'resnet101',
            152: 'resnet152',
        }.get(opt.model_depth)
        if builder_name is not None:
            model = getattr(resnext, builder_name)(opt=opt)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()

    return model, model.parameters()
コード例 #8
0
ファイル: model.py プロジェクト: LAOS-Y/2Stage-EGFR
def generate_model(opt):
    """Build the ResNet of depth ``opt.model_depth`` with default arguments.

    Args:
        opt: Options object providing ``model`` (must be 'resnet') and
            ``model_depth`` (10, 18, 34, 50, 101, 152 or 200).

    Returns:
        The constructed model.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is not
            supported.
    """
    assert opt.model in ['resnet']

    if opt.model == 'resnet':
        depth = opt.model_depth
        assert depth in [10, 18, 34, 50, 101, 152, 200]

        builder_name = {
            10: 'resnet10',
            18: 'resnet18',
            34: 'resnet34',
            50: 'resnet50',
            101: 'resnet101',
            152: 'resnet152',
            200: 'resnet200',
        }.get(depth)
        if builder_name is not None:
            model = getattr(resnet, builder_name)()

    return model
コード例 #9
0
def generate_model(opt):
    """Build the network selected by ``opt`` and optionally set up fine-tuning.

    The model family is chosen by ``opt.model`` ('resnet', 'preresnet',
    'wideresnet', 'resnext' or 'densenet') and the variant by
    ``opt.model_depth``. Unless ``opt.no_cuda`` is truthy, the model is moved
    to the GPU and wrapped in ``nn.DataParallel``.

    When ``opt.pretrain_path`` is set, the checkpoint at that path is loaded,
    its ``'arch'`` entry is asserted equal to ``opt.arch``, its
    ``'state_dict'`` is loaded into the model, and the final layer
    (``classifier`` for densenet, ``fc`` otherwise) is replaced by a new
    ``nn.Linear`` with ``opt.n_finetune_classes`` outputs.

    Args:
        opt: Options object. Besides the attributes named above, the builders
            receive ``n_classes``, ``sample_size``, ``sample_duration`` and,
            depending on the family, ``resnet_shortcut``, ``wide_resnet_k``
            or ``resnext_cardinality``. ``ft_begin_index`` is used when
            fine-tuning.

    Returns:
        tuple: ``(model, parameters)``. With a pretrain path, ``parameters``
        comes from the family's ``get_fine_tuning_parameters(model,
        opt.ft_begin_index)``; otherwise it is ``model.parameters()``.

    Raises:
        AssertionError: If the family, the depth or the checkpoint
            architecture does not match what is supported or expected.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    # Within a family every depth takes the same keyword arguments, so the
    # builder is looked up by name once the depth has been validated.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        builder = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = builder(num_classes=opt.n_classes,
                        shortcut_type=opt.resnet_shortcut,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        if opt.model_depth == 50:
            model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                         shortcut_type=opt.resnet_shortcut,
                                         k=opt.wide_resnet_k,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        builder = getattr(resnext, 'resnet%d' % opt.model_depth)
        model = builder(num_classes=opt.n_classes,
                        shortcut_type=opt.resnet_shortcut,
                        cardinality=opt.resnext_cardinality,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        builder = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)
        model = builder(num_classes=opt.n_classes,
                        shortcut_type=opt.resnet_shortcut,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        builder = getattr(densenet, 'densenet%d' % opt.model_depth)
        model = builder(num_classes=opt.n_classes,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            # The real network sits behind ``model.module`` once wrapped in
            # DataParallel; the replacement head is moved to the GPU as well.
            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            assert opt.arch == pretrain['arch']

            from collections import OrderedDict

            # The CPU model is not wrapped in DataParallel, so a leading
            # 'module.' on checkpoint keys has to go. Strip it only when it
            # is actually present: slicing every key unconditionally would
            # corrupt the keys of a checkpoint saved from an unwrapped model.
            prefix = 'module.'
            new_state_dict = OrderedDict()
            for k, v in pretrain['state_dict'].items():
                name = k[len(prefix):] if k.startswith(prefix) else k
                new_state_dict[name] = v
            model.load_state_dict(new_state_dict)

            if opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    # No checkpoint requested: every parameter of the fresh model is returned.
    return model, model.parameters()
コード例 #10
0
def generate_model(opt, phase):
    """Build the segmentation or classification network for ``phase``.

    Args:
        opt: Options object. For ``phase == 'segment'`` the attributes read
            are ``seg_model`` (must be 'deeplab'), ``in_channel`` and
            ``n_classes``. For ``phase == 'classify'`` they are ``cla_model``
            ('resnet', 'preresnet', 'wideresnet', 'resnext' or 'densenet'),
            ``cla_model_depth``, ``n_classes`` and, depending on the family,
            ``cla_resnet_shortcut``, ``wide_resnet_k`` or
            ``resnext_cardinality``. ``no_cuda`` is always read.
        phase: Either 'segment' or 'classify'.

    Returns:
        tuple: ``(model, model.parameters())``. Unless ``opt.no_cuda`` is
        truthy the model is moved to the GPU and wrapped in
        ``nn.DataParallel``.

    Raises:
        ValueError: If ``phase`` is neither 'segment' nor 'classify'.
            (Previously this left ``model`` unassigned and failed later with
            an UnboundLocalError.)
        AssertionError: If the requested model name or depth is unsupported.
    """
    if phase not in ('segment', 'classify'):
        raise ValueError(
            "Unknown phase {!r}; expected 'segment' or 'classify'"
            .format(phase))

    if phase == 'segment':
        assert opt.seg_model in ['deeplab']
        if opt.seg_model == 'deeplab':
            model = deeplab.Net(in_channel=opt.in_channel,
                                num_classes=opt.n_classes)
    else:
        assert opt.cla_model in [
            'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
        ]

        # Within a family every depth takes the same keyword arguments, so
        # the builder is looked up by name once the depth is validated.
        if opt.cla_model == 'resnet':
            assert opt.cla_model_depth in [10, 18, 34, 50, 101, 152, 200]

            builder = getattr(resnet, 'resnet%d' % opt.cla_model_depth)
            model = builder(num_classes=opt.n_classes,
                            shortcut_type=opt.cla_resnet_shortcut)
        elif opt.cla_model == 'wideresnet':
            assert opt.cla_model_depth in [50]

            if opt.cla_model_depth == 50:
                model = wide_resnet.resnet50(
                    num_classes=opt.n_classes,
                    shortcut_type=opt.cla_resnet_shortcut,
                    k=opt.wide_resnet_k)
        elif opt.cla_model == 'resnext':
            assert opt.cla_model_depth in [50, 101, 152]

            builder = getattr(resnext, 'resnet%d' % opt.cla_model_depth)
            model = builder(num_classes=opt.n_classes,
                            shortcut_type=opt.cla_resnet_shortcut,
                            cardinality=opt.resnext_cardinality)
        elif opt.cla_model == 'preresnet':
            assert opt.cla_model_depth in [18, 34, 50, 101, 152, 200]

            builder = getattr(pre_act_resnet,
                              'resnet%d' % opt.cla_model_depth)
            model = builder(num_classes=opt.n_classes,
                            shortcut_type=opt.cla_resnet_shortcut)
        elif opt.cla_model == 'densenet':
            assert opt.cla_model_depth in [121, 169, 201, 264]

            builder = getattr(densenet, 'densenet%d' % opt.cla_model_depth)
            model = builder(num_classes=opt.n_classes)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
    return model, model.parameters()
コード例 #11
0
def generate_model(opt):
    """Build the network selected by ``opt`` and adapt it to ``opt.modality``.

    The family is chosen by ``opt.model`` ('resnet', 'resnetl', 'resnext' or
    'c3d') and the variant by ``opt.model_depth``. When ``opt.pretrain_path``
    is set, the checkpoint is loaded and its ``'arch'`` entry is asserted
    equal to ``opt.arch``. The model is then passed through one of the
    module's first-layer helpers depending on ``opt.modality`` ('RGB',
    'Depth' or 'RGB-D'), and once more if the first ``nn.Conv3d`` has a
    temporal kernel size larger than ``opt.sample_duration``.

    Args:
        opt: Options object. The attributes read are ``model``,
            ``model_depth``, ``n_classes``, ``sample_size``,
            ``sample_duration``, ``resnet_shortcut``, ``resnext_cardinality``,
            ``no_cuda``, ``pretrain_path``, ``arch``, ``n_finetune_classes``,
            ``modality`` and ``ft_begin_index``.

    Returns:
        tuple: ``(model, parameters)`` where ``parameters`` comes from the
        selected family's ``get_fine_tuning_parameters(model,
        opt.ft_begin_index)``.

    Raises:
        AssertionError: If the family, the depth or the checkpoint
            architecture does not match what is supported or expected.
        IndexError: If the model contains no ``nn.Conv3d`` module.
    """
    assert opt.model in ['resnet', 'resnetl', 'resnext', 'c3d']

    if opt.model == 'resnet':
        assert opt.model_depth in [10]

        from models.resnet import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnet.resnet10(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'resnetl':
        assert opt.model_depth in [10]

        from models.resnetl import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnetl.resnetl10(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)

    elif opt.model == 'resnext':
        assert opt.model_depth in [101]

        from models.resnext import get_fine_tuning_parameters

        if opt.model_depth == 101:
            model = resnext.resnet101(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
    elif opt.model == 'c3d':
        assert opt.model_depth in [10]

        from models.c3d import get_fine_tuning_parameters
        if opt.model_depth == 10:

            model = c3d.c3d_v1(sample_size=opt.sample_size,
                               sample_duration=opt.sample_duration,
                               num_classes=opt.n_classes)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']
            model.load_state_dict(pretrain['state_dict'])

            # For c3d, ``fc`` is indexed, so it is presumably a container
            # whose first element is the linear layer -- TODO confirm.
            if opt.model == 'c3d':  # CHECK HERE
                model.module.fc = nn.Linear(model.module.fc[0].in_features,
                                            opt.n_finetune_classes)
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
            model.module.fc = model.module.fc.cuda()

        if opt.modality == 'RGB' and opt.model != 'c3d':
            print("[INFO]: RGB model is used for init model")
            model = _modify_first_conv_layer(
                model, 3, 3)  ##### Check models trained (3,7,7) or (7,7,7)
        elif opt.modality == 'Depth':
            print(
                "[INFO]: Converting the pretrained model to Depth init model")
            model = _construct_depth_model(model)
            print("[INFO]: Done. Depth model ready.")
        elif opt.modality == 'RGB-D':
            print(
                "[INFO]: Converting the pretrained model to RGB+D init model")
            model = _construct_rgbdepth_model(model)
            print("[INFO]: Done. RGB-D model ready.")

        # First Conv3d in module traversal order; IndexError if none exists.
        conv_layer = [
            m for m in model.modules() if isinstance(m, nn.Conv3d)
        ][0]
        if conv_layer.kernel_size[0] > opt.sample_duration:
            model = _modify_first_conv_layer(model,
                                             int(opt.sample_duration / 2), 1)

        parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        return model, parameters

    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # map_location='cpu' lets a checkpoint saved from a CUDA device
            # be deserialized on a host where CUDA is unavailable.
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            assert opt.arch == pretrain['arch']
            model.load_state_dict(pretrain['state_dict'])

        if opt.modality == 'RGB' and opt.model != 'c3d':
            print("[INFO]: RGB model is used for init model")
            model = _modify_first_conv_layer(model, 3, 3)
        elif opt.modality == 'Depth':
            print(
                "[INFO]: Converting the pretrained model to Depth init model")
            model = _construct_depth_model(model)
            print("[INFO]: Depth model ready.")
        elif opt.modality == 'RGB-D':
            print(
                "[INFO]: Converting the pretrained model to RGB-D init model")
            model = _construct_rgbdepth_model(model)
            print("[INFO]: Done. RGB-D model ready.")

        # First Conv3d in module traversal order; IndexError if none exists.
        conv_layer = [
            m for m in model.modules() if isinstance(m, nn.Conv3d)
        ][0]
        if conv_layer.kernel_size[0] > opt.sample_duration:
            print("[INFO]: RGB model is used for init model")
            model = _modify_first_conv_layer(model,
                                             int(opt.sample_duration / 2), 1)

        # NOTE(review): unlike the CUDA path, the head is replaced here even
        # without a pretrain path, and the c3d head keeps its original output
        # size instead of opt.n_finetune_classes -- confirm which is intended.
        if opt.model == 'c3d':  # CHECK HERE
            model.fc = nn.Linear(model.fc[0].in_features,
                                 model.fc[0].out_features)
        else:
            model.fc = nn.Linear(model.fc.in_features, opt.n_finetune_classes)

        parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        return model, parameters
コード例 #12
0
ファイル: model.py プロジェクト: MartinSavc/MedicalNet
def generate_model(opt):
    """Build the ResNet selected by ``opt`` and prepare its parameters.

    Args:
        opt: Options object. Reads ``model``, ``model_depth``, ``input_W``,
            ``input_H``, ``input_D``, ``resnet_shortcut``, ``no_cuda``,
            ``n_seg_classes``, ``phase``, ``pretrain_path`` and
            ``new_layer_names``. ``gpu_id`` is read only when
            ``opt.no_cuda`` is false.

    Returns:
        A ``(model, parameters)`` tuple. When ``opt.phase != 'test'``,
        ``parameters`` is a dict with the lists ``'base_parameters'`` and
        ``'new_parameters'``; a parameter is "new" when its name contains
        any entry of ``opt.new_layer_names``. In the test phase
        ``parameters`` is ``model.parameters()``.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is not one
            of the supported values.
    """
    assert opt.model in ['resnet']

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        # All depths take the same keyword arguments, so resolve the
        # constructor by name instead of repeating the call seven times.
        build = getattr(resnet, 'resnet{}'.format(int(opt.model_depth)))
        model = build(sample_input_W=opt.input_W,
                      sample_input_H=opt.input_H,
                      sample_input_D=opt.input_D,
                      shortcut_type=opt.resnet_shortcut,
                      no_cuda=opt.no_cuda,
                      num_seg_classes=opt.n_seg_classes)

    if not opt.no_cuda:
        if len(opt.gpu_id) > 1:
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=opt.gpu_id)
        else:
            import os
            # Restrict the process to the single requested device before
            # the model is moved to CUDA.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu_id[0])
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=None)

    if opt.phase == 'test':
        return model, model.parameters()

    # load pretrain
    if opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        # map_location keeps deserialisation independent of the devices the
        # checkpoint was saved from; load_state_dict copies each tensor onto
        # the device of the matching model parameter afterwards.
        pretrain = torch.load(opt.pretrain_path, map_location='cpu')
        net_dict = model.state_dict()
        # Only keys that exist in the current model are taken over.
        # NOTE(review): without DataParallel (opt.no_cuda) the model keys
        # carry no 'module.' prefix, so a checkpoint saved from a wrapped
        # model would match nothing here -- confirm against the checkpoints
        # actually used.
        net_dict.update({
            k: v
            for k, v in pretrain['state_dict'].items()
            if k in net_dict
        })
        model.load_state_dict(net_dict)

    new_parameters = [
        p for pname, p in model.named_parameters()
        if any(layer_name in pname for layer_name in opt.new_layer_names)
    ]
    # A set makes the membership test below O(1) per parameter.
    new_parameter_ids = {id(p) for p in new_parameters}
    base_parameters = [
        p for p in model.parameters() if id(p) not in new_parameter_ids
    ]
    parameters = {
        'base_parameters': base_parameters,
        'new_parameters': new_parameters
    }

    return model, parameters
コード例 #13
0
ファイル: model.py プロジェクト: TudorAnca/Real-time-GesRec
def generate_model(opt):
    """Build the network named by ``opt.model`` and set it up for training.

    With CUDA enabled the model is moved to the GPU and wrapped in
    ``nn.DataParallel``. If ``opt.pretrain_path`` is set, the checkpoint's
    ``'state_dict'`` is loaded and the final classification layer is
    replaced by one with ``opt.n_finetune_classes`` outputs.

    Args:
        opt: Options object. Reads ``model``, ``n_classes``, ``no_cuda``
            and ``pretrain_path``, plus the architecture-specific fields
            used by the chosen constructor (``sample_size``,
            ``sample_duration``, ``width_mult``, ``groups``, ``version``,
            ``model_depth``, ``resnet_shortcut``, ``resnext_cardinality``).
            With a checkpoint it also reads ``n_finetune_classes`` and
            ``ft_portion``; ``arch`` is read only on the CPU path.

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` is
        ``model.parameters()`` when no checkpoint is given, otherwise the
        result of the architecture's ``get_fine_tuning_parameters``.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is
            unsupported, or (CPU path only) if ``opt.arch`` differs from
            the checkpoint's ``'arch'`` entry.
    """
    assert opt.model in [
        'c3d', 'squeezenet', 'mobilenet', 'resnext', 'resnet', 'resnetl',
        'shufflenet', 'mobilenetv2', 'shufflenetv2'
    ]

    if opt.model == 'c3d':
        from models.c3d import get_fine_tuning_parameters
        model = c3d.get_model(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'squeezenet':
        from models.squeezenet import get_fine_tuning_parameters
        model = squeezenet.get_model(version=opt.version,
                                     num_classes=opt.n_classes,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'shufflenet':
        from models.shufflenet import get_fine_tuning_parameters
        model = shufflenet.get_model(groups=opt.groups,
                                     width_mult=opt.width_mult,
                                     num_classes=opt.n_classes)
    elif opt.model == 'shufflenetv2':
        from models.shufflenetv2 import get_fine_tuning_parameters
        model = shufflenetv2.get_model(num_classes=opt.n_classes,
                                       sample_size=opt.sample_size,
                                       width_mult=opt.width_mult)
    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        # Every depth takes the same keyword arguments; pick it by name.
        build = getattr(resnext, 'resnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'resnetl':
        assert opt.model_depth in [10]

        from models.resnetl import get_fine_tuning_parameters

        model = resnetl.resnetl10(num_classes=opt.n_classes,
                                  shortcut_type=opt.resnet_shortcut,
                                  sample_size=opt.sample_size,
                                  sample_duration=opt.sample_duration)
    elif opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        build = getattr(resnet, 'resnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
        pytorch_total_params = sum(p.numel() for p in model.parameters()
                                   if p.requires_grad)
        print("Total number of trainable parameters: ", pytorch_total_params)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    # Deserialise onto the CPU on both paths so that a checkpoint saved from
    # a GPU can also be opened when CUDA is disabled; load_state_dict copies
    # the tensors onto the model's own device.
    pretrain = torch.load(opt.pretrain_path,
                          map_location=torch.device('cpu'))
    if not use_cuda:
        # The architecture check was only ever active on the CPU path.
        assert opt.arch == pretrain['arch']
    model.load_state_dict(pretrain['state_dict'])

    # Only the CUDA path wraps the model in DataParallel, so only there do
    # the layers live under `.module`. The CPU path used to dereference
    # `.module` too and failed with AttributeError.
    net = model.module if use_cuda else model
    # NOTE(review): the two paths were written with different dropout rates
    # for the mobilenet/shufflenet head (0.5 on CUDA, 0.9 on CPU); both are
    # kept as found -- confirm which one is intended.
    dropout_p = 0.5 if use_cuda else 0.9

    if opt.model in [
            'mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2'
    ]:
        net.classifier = nn.Sequential(
            nn.Dropout(dropout_p),
            nn.Linear(net.classifier[1].in_features,
                      opt.n_finetune_classes))
        new_head = net.classifier
    elif opt.model == 'squeezenet':
        net.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv3d(net.classifier[1].in_channels,
                      opt.n_finetune_classes,
                      kernel_size=1), nn.ReLU(inplace=True),
            nn.AvgPool3d((1, 4, 4), stride=1))
        new_head = net.classifier
    else:
        net.fc = nn.Linear(net.fc.in_features, opt.n_finetune_classes)
        new_head = net.fc

    if use_cuda:
        # The freshly created layers start on the CPU; Module.cuda() moves
        # them in place.
        new_head.cuda()

    # Both paths call the same helper, so both pass opt.ft_portion (the CPU
    # path used to pass opt.ft_begin_index to the very same function).
    parameters = get_fine_tuning_parameters(model, opt.ft_portion)
    return model, parameters
コード例 #14
0
def generate_model(opt):
    """Build the network named by ``opt.model`` and set it up for training.

    With CUDA enabled the model is moved to the GPU and wrapped in
    ``nn.DataParallel``. If ``opt.pretrain_path`` is set, checkpoint weights
    are loaded and the final linear layer is replaced by one with
    ``opt.n_finetune_classes`` outputs.

    Args:
        opt: Options object. Reads ``model``, ``model_depth``,
            ``n_classes``, ``no_cuda`` and ``pretrain_path``, plus the
            constructor fields of the chosen architecture
            (``resnet_shortcut``, ``sample_size``, ``sample_duration``,
            ``wide_resnet_k``, ``resnext_cardinality``). With a checkpoint
            it also reads ``n_finetune_classes`` and, except for senet,
            ``ft_begin_index``; ``arch`` is read only on the CPU path.

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` is
        ``model.parameters()`` when no checkpoint is given or the model is
        a senet, otherwise the result of the architecture's
        ``get_fine_tuning_parameters``.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is
            unsupported, or (CPU path only) if ``opt.arch`` differs from
            the checkpoint's ``'arch'`` entry.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet', 'senet'
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        # Every depth takes the same keyword arguments; pick it by name.
        build = getattr(resnet, 'resnet{}'.format(int(opt.model_depth)))
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = wide_resnet.resnet50(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            k=opt.wide_resnet_k,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnet{}'.format(int(opt.model_depth)))
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        build = getattr(pre_act_resnet,
                        'resnet{}'.format(int(opt.model_depth)))
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        build = getattr(densenet,
                        'densenet{}'.format(int(opt.model_depth)))
        model = build(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'senet':
        assert opt.model_depth in [50, 101, 152, 154, 5032, 10132]
        # Depth codes map to constructor names in the senet module.
        # NOTE(review): 5032 resolves to `resnext50_32x4d` while 10132 uses
        # the `se_`-prefixed name; kept as found -- confirm the 5032 name
        # against the senet module.
        senet_builders = {
            50: 'se_resnet50',
            101: 'se_resnet101',
            152: 'se_resnet152',
            154: 'senet154',
            5032: 'resnext50_32x4d',
            10132: 'se_resnext101_32x4d',
        }
        build = getattr(senet, senet_builders[opt.model_depth])
        model = build(num_classes=opt.n_classes, pretrained=None)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    # Deserialise onto the CPU so the checkpoint opens regardless of the
    # devices it was saved from; load_state_dict copies the tensors onto
    # the model's own device.
    pretrain = torch.load(opt.pretrain_path, map_location='cpu')

    if use_cuda:
        # Take over everything except the old `fc` weights, whose shape
        # depends on the number of classes the checkpoint was trained for.
        model_dict = model.state_dict()
        model_dict.update({
            k: v
            for k, v in pretrain['state_dict'].items()
            if "module.fc" not in k
        })
        model.load_state_dict(model_dict)
        # DataParallel keeps the real network under `.module`.
        net = model.module
    else:
        assert opt.arch == pretrain['arch']
        model.load_state_dict(pretrain['state_dict'])
        net = model

    if opt.model == 'densenet':
        net.classifier = nn.Linear(net.classifier.in_features,
                                   opt.n_finetune_classes)
        new_head = net.classifier
    elif opt.model == 'senet':
        net.last_linear = nn.Linear(net.last_linear.in_features,
                                    opt.n_finetune_classes)
        new_head = net.last_linear
    else:
        net.fc = nn.Linear(net.fc.in_features, opt.n_finetune_classes)
        new_head = net.fc

    if use_cuda:
        # The freshly created layer starts on the CPU; Module.cuda() moves
        # it in place.
        new_head.cuda()

    if opt.model == 'senet':
        # No get_fine_tuning_parameters is imported for senet, so all
        # parameters are returned. The CPU path used to fall through to the
        # generic `fc` branch and the unbound helper instead.
        return model, model.parameters()

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
コード例 #15
0
def generate_model(opt):
    assert opt.model in [
        'resnet', 'resnet_skeleton', 'preresnet', 'wideresnet', 'resnext',
        'densenet'
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnet.resnet10(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 18:
            model = resnet.resnet18(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 34:
            model = resnet.resnet34(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 50:
            model = resnet.resnet50(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = resnet.resnet101(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = resnet.resnet152(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
        elif opt.model_depth == 200:
            model = resnet.resnet200(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'resnet_skeleton':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet_skeleton import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnet_skeleton.resnet_skeleton10(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 18:
            model = resnet_skeleton.resnet_skeleton18(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 34:
            model = resnet_skeleton.resnet_skeleton34(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 50:
            model = resnet_skeleton.resnet_skeleton50(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = resnet_skeleton.resnet_skeleton101(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = resnet_skeleton.resnet_skeleton152(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 200:
            model = resnet_skeleton.resnet_skeleton200(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        if opt.model_depth == 50:
            model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                         shortcut_type=opt.resnet_shortcut,
                                         k=opt.wide_resnet_k,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        if opt.model_depth == 50:
            model = resnext.resnet50(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     cardinality=opt.resnext_cardinality,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = resnext.resnet101(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = resnext.resnet152(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        if opt.model_depth == 18:
            model = pre_act_resnet.resnet18(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 34:
            model = pre_act_resnet.resnet34(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 50:
            model = pre_act_resnet.resnet50(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = pre_act_resnet.resnet101(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = pre_act_resnet.resnet152(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 200:
            model = pre_act_resnet.resnet200(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        if opt.model_depth == 121:
            model = densenet.densenet121(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
        elif opt.model_depth == 169:
            model = densenet.densenet169(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
        elif opt.model_depth == 201:
            model = densenet.densenet201(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
        elif opt.model_depth == 264:
            model = densenet.densenet264(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)

    if not opt.no_cuda:
        if opt.cuda_id is None:
            model = model.cuda()
        else:
            model = model.cuda(opt.cuda_id)
        # model = nn.DataParallel(model, device_ids=None)
        if opt.cuda_id is None:
            model = nn.DataParallel(model, device_ids=None)
        else:
            model = nn.DataParallel(model, device_ids=[opt.cuda_id])

        if opt.pretrain_path:
            print('    loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)

            if opt.model == 'resnet_skeleton':
                pretrained_dict = pretrain['state_dict']
                model_dict = model.state_dict()
                # print('----------------')
                # for k, v in pretrained_dict.items():
                #     if k in model_dict:
                #         print(k)
                # print('----------------')

                # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                pretrained_dict = {
                    k: v
                    for k, v in pretrained_dict.items()
                    if k in model_dict and 'fc' not in k
                }  ## for concatenate
                model_dict.update(pretrained_dict)
                model.load_state_dict(model_dict)
            else:
                assert opt.arch == pretrain['arch']
                model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                if opt.cuda_id is None:
                    model.module.classifier = model.module.classifier.cuda()
                else:
                    model.module.classifier = model.module.classifier.cuda(
                        opt.cuda_id)
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                if opt.cuda_id is None:
                    model.module.fc = model.module.fc.cuda()
                else:
                    model.module.fc = model.module.fc.cuda(opt.cuda_id)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
コード例 #16
0
def generate_model(opt):
    """Instantiate the network described by ``opt``.

    ``opt.mode`` decides the ``last_fc`` flag handed to the constructor
    (``True`` for 'score', ``False`` for 'feature').  ``opt.model_name`` and
    ``opt.model_depth`` pick the constructor, e.g. ``resnet.resnet18`` or
    ``densenet.densenet121``.  Unless ``opt.no_cuda`` is set, the model is
    moved to the GPU and wrapped in ``nn.DataParallel``.

    Args:
        opt: options namespace providing ``mode``, ``model_name``,
            ``model_depth``, ``n_classes``, ``sample_size``,
            ``sample_duration``, ``no_cuda`` and the family-specific
            ``resnet_shortcut`` / ``wide_resnet_k`` / ``resnext_cardinality``.

    Returns:
        The constructed model (wrapped in ``nn.DataParallel`` when CUDA is
        enabled).

    Raises:
        AssertionError: if the mode, model name, or depth is not supported.
    """
    mode = opt.mode
    assert mode in ['score', 'feature']
    last_fc = mode == 'score'

    arch = opt.model_name
    assert arch in ['resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet']

    # Depths each architecture offers a constructor for.
    supported = {
        'resnet': [10, 18, 34, 50, 101, 152, 200],
        'wideresnet': [50],
        'resnext': [50, 101, 152],
        'preresnet': [18, 34, 50, 101, 152, 200],
        'densenet': [121, 169, 201, 264],
    }[arch]
    assert opt.model_depth in supported
    # index() compares with ==, so this yields the canonical int depth.
    depth = supported[supported.index(opt.model_depth)]

    # Resolve the implementing module lazily so only the selected one is
    # ever looked up.
    if arch == 'resnet':
        family = resnet
    elif arch == 'wideresnet':
        family = wide_resnet
    elif arch == 'resnext':
        family = resnext
    elif arch == 'preresnet':
        family = pre_act_resnet
    else:
        family = densenet
    stem = 'densenet' if arch == 'densenet' else 'resnet'
    build = getattr(family, '{}{}'.format(stem, depth))

    # Keyword arguments are assembled in the order the constructors are
    # conventionally called with; only some families take the extra knobs.
    kwargs = {'num_classes': opt.n_classes}
    if arch != 'densenet':
        kwargs['shortcut_type'] = opt.resnet_shortcut
    if arch == 'wideresnet':
        kwargs['k'] = opt.wide_resnet_k
    elif arch == 'resnext':
        kwargs['cardinality'] = opt.resnext_cardinality
    kwargs['sample_size'] = opt.sample_size
    kwargs['sample_duration'] = opt.sample_duration
    kwargs['last_fc'] = last_fc

    model = build(**kwargs)

    if opt.no_cuda:
        return model
    return nn.DataParallel(model.cuda(), device_ids=None)
コード例 #17
0
ファイル: model.py プロジェクト: xxy19404/3D-ResNet-Pytorch
def generate_model(opt):
    """Build the network selected by ``opt`` and the parameters to optimise.

    ``opt.mode`` sets the ``last_fc`` constructor flag (``True`` for
    'score', ``False`` for 'feature'); ``opt.model_name`` and
    ``opt.model_depth`` select the constructor.

    Args:
        opt: options namespace providing ``mode``, ``model_name``,
            ``model_depth``, ``n_classes``, ``sample_size``,
            ``sample_duration`` and the family-specific
            ``resnet_shortcut`` / ``wide_resnet_k`` / ``resnext_cardinality``.
            ``ft_begin_index`` is read for the 'resnet' family only.

    Returns:
        tuple: ``(model, policies)``.  For 'resnet', ``policies`` is the
        result of ``resnet.get_fine_tuning_parameters(model,
        opt.ft_begin_index)``.  For every other family it is
        ``model.parameters()``; previously the second element was never
        assigned for those families and the function raised
        ``UnboundLocalError`` on return.

    Raises:
        AssertionError: if the mode, model name, or depth is not supported.
    """
    assert opt.mode in ['score', 'feature']
    last_fc = opt.mode == 'score'

    assert opt.model_name in ['resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet']

    # Arguments every constructor takes; family-specific ones are added below.
    kwargs = {
        'num_classes': opt.n_classes,
        'sample_size': opt.sample_size,
        'sample_duration': opt.sample_duration,
        'last_fc': last_fc,
    }

    if opt.model_name == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        factory = getattr(resnet, 'resnet%d' % opt.model_depth)
        kwargs['shortcut_type'] = opt.resnet_shortcut
    elif opt.model_name == 'wideresnet':
        assert opt.model_depth in [50]
        factory = getattr(wide_resnet, 'resnet%d' % opt.model_depth)
        kwargs['shortcut_type'] = opt.resnet_shortcut
        kwargs['k'] = opt.wide_resnet_k
    elif opt.model_name == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        factory = getattr(resnext, 'resnet%d' % opt.model_depth)
        kwargs['shortcut_type'] = opt.resnet_shortcut
        kwargs['cardinality'] = opt.resnext_cardinality
    elif opt.model_name == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]
        factory = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)
        kwargs['shortcut_type'] = opt.resnet_shortcut
    elif opt.model_name == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]
        factory = getattr(densenet, 'densenet%d' % opt.model_depth)

    model = factory(**kwargs)

    if opt.model_name == 'resnet':
        policies = resnet.get_fine_tuning_parameters(model, opt.ft_begin_index)
    else:
        # No fine-tuning policy is wired up for the other families here, so
        # hand back every parameter instead of failing on an unbound name.
        policies = model.parameters()
    return model, policies
コード例 #18
0
def generate_model(opt):
    """Build the network selected by ``opt.model`` and optionally load pretrained weights.

    Args:
        opt: options namespace.  Reads ``model``, ``no_cuda`` and
            ``pretrain_path`` on every call; ``model_depth`` and the
            constructor arguments for the resnet families; and, when a
            checkpoint is loaded, ``n_finetune_classes`` plus ``two_step``,
            ``test`` and ``is_AE`` (GPU path) or ``ft_begin_index`` (CPU
            path).

    Returns:
        tuple: ``(model, parameters)``.  ``parameters`` is
        ``model.parameters()`` except on the CPU fine-tuning path of the
        resnet families, where it is the result of that family's
        ``get_fine_tuning_parameters(model, opt.ft_begin_index)``.

    Raises:
        AssertionError: if the model name or depth is unsupported, or (CPU
            path) the checkpoint's ``arch`` is not 'resnet' / 'resnet_AE'.
    """
    assert opt.model in [
        'resnet', 'resnet_AE', 'resnet_mask', 'resnet_comp', 'unet', 'icnet',
        'icnet_res', 'icnet_res_2D', 'icnet_res_2Dt', 'icnet_DBI',
        'icnet_deep', 'icnet_deep_gate', 'icnet_deep_gate_2step'
    ]

    # The imports below bind this name locally, so it must exist for the
    # architectures that import nothing (unet / icnet*); otherwise the CPU
    # fine-tuning path fails with UnboundLocalError.
    get_fine_tuning_parameters = None

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)

    elif opt.model in ('resnet_AE', 'resnet_mask', 'resnet_comp'):
        # Only these depths have a constructor branch; the wider list that
        # used to be asserted let e.g. depth 10 through and left ``model``
        # unbound.
        assert opt.model_depth in [18, 34, 50]

        if opt.model == 'resnet_AE':
            from models.resnet_AE import get_fine_tuning_parameters
            family = resnet_AE
        elif opt.model == 'resnet_mask':
            from models.resnet_mask import get_fine_tuning_parameters
            family = resnet_mask
        else:
            from models.resnet_comp import get_fine_tuning_parameters
            family = resnet_comp

        build = getattr(family, 'resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration,
                      is_gray=opt.is_gray,
                      opt=opt)

    elif opt.model == 'unet':
        model = unet_mask.UNet3D(opt=opt)
    elif opt.model == 'icnet':
        model = icnet_mask.ICNet3D(opt=opt)
    elif opt.model == 'icnet_res':
        model = icnet_res.ICNetResidual3D(opt=opt)
    elif opt.model == 'icnet_res_2D':
        model = icnet_res.ICNetResidual2D(opt=opt)
    elif opt.model == 'icnet_res_2Dt':
        model = icnet_res.ICNetResidual2Dt(opt=opt)
    elif opt.model == 'icnet_DBI':
        model = icnet_res.ICNetResidual_DBI(opt=opt)
    elif opt.model == 'icnet_deep':
        model = icnet_res.ICNetDeep(opt=opt)
    elif opt.model == 'icnet_deep_gate':
        model = icnet_res.ICNetDeepGate(opt=opt)
    elif opt.model == 'icnet_deep_gate_2step':
        model = icnet_res.ICNetDeepGate2step(opt=opt)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            print('loading from', pretrain['arch'])

            # Two-step testing reads its weights from a separate entry.
            if opt.two_step and opt.test:
                source_state = pretrain['state_dict_1']
            else:
                source_state = pretrain['state_dict']
            target_state = model.state_dict()

            # Partial load: copy only the entries the checkpoint provides
            # and report the ones that keep their fresh initialisation.
            print('Not loaded :')
            matched = {}
            for key in target_state:
                if key in source_state:
                    matched[key] = source_state[key]
                else:
                    print(key)
            print('length :', len(matched))
            target_state.update(matched)
            model.load_state_dict(target_state)

            if not opt.is_AE:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                # nn.Linear allocates on the CPU; move the new head next to
                # the rest of the (already CUDA-resident) model.
                model.module.fc = model.module.fc.cuda()
            return model, model.parameters()

    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert pretrain['arch'] in ['resnet', 'resnet_AE']

            model.load_state_dict(pretrain['state_dict'])

            # On this path the model is not wrapped in DataParallel, so the
            # classifier lives directly on ``model`` (there is no ``.module``).
            model.fc = nn.Linear(model.fc.in_features,
                                 opt.n_finetune_classes)

            if get_fine_tuning_parameters is None:
                return model, model.parameters()
            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
コード例 #19
0
def generate_model(opt):
    """Build the network selected by ``opt`` and choose the parameters to train.

    The constructor is picked from ``opt.model`` and ``opt.model_depth``.
    Unless ``opt.no_cuda`` is set, the model is moved to the GPU and wrapped
    in ``nn.DataParallel`` (``device_ids=opt.device_ids``).  When
    ``opt.pretrain_path`` is set, the checkpoint's ``'state_dict'`` is loaded,
    the final linear layer (``classifier`` for DenseNet, ``fc`` otherwise) is
    replaced by a new ``nn.Linear`` with ``opt.n_finetune_classes`` outputs,
    and the trainable parameters are chosen by
    ``get_fine_tuning_parameters(model, opt.ft_begin_index)``.

    Args:
        opt: Options object.  Attributes read: ``model``, ``model_depth``,
            ``n_classes``, ``sample_size``, ``sample_duration``,
            ``resnet_shortcut``, ``wide_resnet_k``, ``resnext_cardinality``,
            ``no_cuda``, ``device_ids``, ``pretrain_path``, ``arch``,
            ``n_finetune_classes`` and ``ft_begin_index`` (the last few only
            on the paths that need them).

    Returns:
        ``(model, parameters)``.  Without a pretrained checkpoint
        ``parameters`` is ``model.parameters()``; otherwise it is whatever
        ``get_fine_tuning_parameters`` returns.

    Raises:
        AssertionError: If the model name or depth is unsupported, or the
            checkpoint's ``'arch'`` entry differs from ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    # Keyword arguments that every constructor below receives.
    common = {
        'num_classes': opt.n_classes,
        'sample_size': opt.sample_size,
        'sample_duration': opt.sample_duration,
    }

    # Each family exposes one constructor per depth named after that depth
    # (``resnet50``, ``densenet121``, ...); the assert guards the lookup.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = build(shortcut_type=opt.resnet_shortcut, **common)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        build = getattr(wide_resnet, 'resnet%d' % opt.model_depth)
        model = build(shortcut_type=opt.resnet_shortcut,
                      k=opt.wide_resnet_k,
                      **common)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnet%d' % opt.model_depth)
        model = build(shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      **common)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        build = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)
        model = build(shortcut_type=opt.resnet_shortcut, **common)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        build = getattr(densenet, 'densenet%d' % opt.model_depth)
        model = build(**common)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=opt.device_ids)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    if opt.no_cuda:
        # Remap storages to the CPU so that a checkpoint written on a GPU
        # machine can still be read on a host without CUDA.
        pretrain = torch.load(opt.pretrain_path, map_location='cpu')
    else:
        pretrain = torch.load(opt.pretrain_path)
    # Ensure that the pretrained model has the same architecture.
    assert opt.arch == pretrain['arch']

    model.load_state_dict(pretrain['state_dict'])

    # Only the GPU path wraps the network in DataParallel, so that is the
    # only case in which the layers live under ``model.module``.
    net = model if opt.no_cuda else model.module

    # Replace the output layer so that it emits opt.n_finetune_classes scores.
    if opt.model == 'densenet':
        head = nn.Linear(net.classifier.in_features, opt.n_finetune_classes)
        net.classifier = head if opt.no_cuda else head.cuda()
    else:
        head = nn.Linear(net.fc.in_features, opt.n_finetune_classes)
        net.fc = head if opt.no_cuda else head.cuda()

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
コード例 #20
0
def generate_model(opt):
    """Build the network selected by ``opt`` and choose the parameters to train.

    The constructor is picked from ``opt.model`` and ``opt.model_depth``.
    Unless ``opt.no_cuda`` is set, the model is moved to ``'cuda'`` and
    wrapped in ``nn.DataParallel``.  When ``opt.pretrain_path`` is set, the
    checkpoint's ``'state_dict'`` is loaded, the final linear layer
    (``classifier`` for DenseNet, ``fc`` otherwise) is replaced by a new
    ``nn.Linear`` with ``opt.n_finetune_classes`` outputs, and the trainable
    parameters are chosen by
    ``get_fine_tuning_parameters(model, opt.ft_begin_index)``.

    Args:
        opt: Options object.  Attributes read: ``model``, ``model_depth``,
            ``n_classes``, ``sample_size``, ``sample_duration``,
            ``resnet_shortcut``, ``wide_resnet_k``, ``resnext_cardinality``,
            ``no_cuda``, ``pretrain_path``, ``arch``, ``n_finetune_classes``
            and ``ft_begin_index`` (the last few only on the paths that need
            them).

    Returns:
        ``(model, parameters)``.  Without a pretrained checkpoint
        ``parameters`` is ``model.parameters()``; otherwise it is whatever
        ``get_fine_tuning_parameters`` returns.

    Raises:
        AssertionError: If the model name or depth is unsupported, or the
            checkpoint's ``'arch'`` entry differs from ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    # Keyword arguments that every constructor below receives.
    common = {
        "num_classes": opt.n_classes,
        "sample_size": opt.sample_size,
        "sample_duration": opt.sample_duration,
    }

    # Each family exposes one constructor per depth named after that depth
    # (``resnet50``, ``densenet121``, ...); the assert guards the lookup.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = build(shortcut_type=opt.resnet_shortcut, **common)

    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        build = getattr(wide_resnet, 'resnet%d' % opt.model_depth)
        model = build(shortcut_type=opt.resnet_shortcut,
                      k=opt.wide_resnet_k,
                      **common)

    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnet%d' % opt.model_depth)
        model = build(shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      **common)

    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        build = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)
        model = build(shortcut_type=opt.resnet_shortcut, **common)

    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        build = getattr(densenet, 'densenet%d' % opt.model_depth)
        model = build(**common)

    if opt.no_cuda:
        device = 'cpu'
    else:
        device = 'cuda'
        model = model.to(device)
        model = nn.DataParallel(model, device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    pretrain = torch.load(opt.pretrain_path, map_location=device)
    assert opt.arch == pretrain['arch']

    # load_state_dict runs on the DataParallel wrapper on GPU and on the bare
    # network on CPU, so the checkpoint keys have to match that layout.
    model.load_state_dict(pretrain['state_dict'])

    # The network is wrapped in DataParallel only on the GPU path.  On CPU
    # there is no ``.module`` attribute, so address the bare model directly
    # instead of failing with AttributeError.
    net = model if opt.no_cuda else model.module

    # Replace the output layer so that it emits opt.n_finetune_classes scores.
    if opt.model == 'densenet':
        net.classifier = nn.Linear(net.classifier.in_features,
                                   opt.n_finetune_classes).to(device)
    else:
        net.fc = nn.Linear(net.fc.in_features,
                           opt.n_finetune_classes).to(device)

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
コード例 #21
0
def generate_C3D_model(opt):
    """Build the 3D network selected by ``opt`` and load its checkpoint.

    The constructor is picked from ``opt.c3d_model_name`` and
    ``opt.c3d_model_depth`` and is called with ``last_fc=True`` when
    ``opt.mode`` is ``'score'`` and ``last_fc=False`` when it is
    ``'feature'``.  The weights are then read from
    ``opt.c3d_model_checkpoint`` and, unless ``opt.no_cuda`` is set, the model
    is moved to ``opt.device``.

    Args:
        opt: Options object.  Attributes read: ``mode``, ``c3d_model_name``,
            ``c3d_model_depth``, ``n_classes``, ``sample_size``,
            ``sample_duration``, ``resnet_shortcut``, ``wide_resnet_k``,
            ``resnext_cardinality``, ``c3d_model_checkpoint``, ``arch``,
            ``no_cuda`` and ``device``.

    Returns:
        The constructed model with the checkpoint weights loaded.

    Raises:
        AssertionError: If the mode, model name or depth is unsupported, or
            the checkpoint's ``'arch'`` entry differs from ``opt.arch``.
        ValueError: If a key of the checkpoint's state dict contains no
            ``'.'`` (raised by ``str.index``).
    """
    assert opt.mode in ['score', 'feature']
    last_fc = opt.mode == 'score'

    assert opt.c3d_model_name in ['resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet']

    # Keyword arguments that every constructor below receives.
    common = {
        'num_classes': opt.n_classes,
        'sample_size': opt.sample_size,
        'sample_duration': opt.sample_duration,
        'last_fc': last_fc,
    }

    # Each family exposes one constructor per depth named after that depth
    # (``resnet50``, ``densenet121``, ...); the assert guards the lookup.
    if opt.c3d_model_name == 'resnet':
        assert opt.c3d_model_depth in [10, 18, 34, 50, 101, 152, 200]

        build = getattr(resnet, 'resnet%d' % opt.c3d_model_depth)
        model = build(shortcut_type=opt.resnet_shortcut, **common)
    elif opt.c3d_model_name == 'wideresnet':
        assert opt.c3d_model_depth in [50]

        build = getattr(wide_resnet, 'resnet%d' % opt.c3d_model_depth)
        model = build(shortcut_type=opt.resnet_shortcut,
                      k=opt.wide_resnet_k,
                      **common)
    elif opt.c3d_model_name == 'resnext':
        assert opt.c3d_model_depth in [50, 101, 152]

        build = getattr(resnext, 'resnet%d' % opt.c3d_model_depth)
        model = build(shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      **common)
    elif opt.c3d_model_name == 'preresnet':
        assert opt.c3d_model_depth in [18, 34, 50, 101, 152, 200]

        build = getattr(pre_act_resnet, 'resnet%d' % opt.c3d_model_depth)
        model = build(shortcut_type=opt.resnet_shortcut, **common)
    elif opt.c3d_model_name == 'densenet':
        assert opt.c3d_model_depth in [121, 169, 201, 264]

        build = getattr(densenet, 'densenet%d' % opt.c3d_model_depth)
        model = build(**common)

    print('loading c3d model from: {}'.format(opt.c3d_model_checkpoint))
    # In CPU mode remap the stored tensors to the CPU so that a checkpoint
    # written on a GPU machine can be read on a host without CUDA.  With
    # map_location=None torch.load keeps its default behaviour.
    map_location = 'cpu' if opt.no_cuda else None
    model_data = torch.load(opt.c3d_model_checkpoint,
                            map_location=map_location)
    print(model_data['arch'])
    assert opt.arch == model_data['arch']

    # Drop the first dotted component of every key before loading into the
    # bare model.
    # NOTE(review): presumably this removes the 'module.' prefix added by
    # nn.DataParallel when the checkpoint was saved -- confirm.
    model_state_dict = {
        key[key.index('.') + 1:]: value
        for key, value in model_data['state_dict'].items()
    }

    model.load_state_dict(model_state_dict)

    if not opt.no_cuda:
        model = model.to(opt.device)

    return model
コード例 #22
0
def generate_model(opt):
    """Build the network selected by ``opt`` and choose the parameters to train.

    The constructor is picked from ``opt.model`` and ``opt.model_depth``;
    ResNets of depth 50 and above additionally receive
    ``model_type=opt.model_type``.  Unless ``opt.no_cuda`` is set, the model
    is moved to ``opt.cuda_id`` and wrapped in ``nn.DataParallel``.

    When ``opt.pretrain_path`` is set, the checkpoint's ``'state_dict'`` is
    loaded and the output layer is replaced:

    * GPU path: DenseNet gets a new ``nn.Linear`` classifier; every other
      family gets a Dropout/Linear/ReLU6 stack ending in
      ``opt.n_finetune_classes`` outputs.  The checkpoint's ``'arch'`` must
      equal ``'<model>-<depth>'``.
    * CPU path: the layer is replaced by a single ``nn.Linear``.  The
      checkpoint's ``'arch'`` must equal ``opt.arch``.

    Args:
        opt: Options object.  Attributes read: ``model``, ``model_depth``,
            ``model_type``, ``n_classes``, ``sample_size``,
            ``sample_duration``, ``resnet_shortcut``, ``wide_resnet_k``,
            ``resnext_cardinality``, ``no_cuda``, ``cuda_id``,
            ``pretrain_path``, ``arch``, ``n_finetune_classes`` and
            ``ft_begin_index`` (only on the paths that need them).

    Returns:
        ``(model, parameters)``.  Without a pretrained checkpoint
        ``parameters`` is ``model.parameters()``.  On the GPU fine-tuning
        path it is a list built from the result of
        ``get_fine_tuning_parameters``; on the CPU fine-tuning path it is
        that result unchanged.

    Raises:
        AssertionError: If the model name or depth is unsupported, or the
            checkpoint's ``'arch'`` entry does not match.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    # Keyword arguments that every constructor below receives.
    common = {
        'num_classes': opt.n_classes,
        'sample_size': opt.sample_size,
        'sample_duration': opt.sample_duration,
    }

    # Each family exposes one constructor per depth named after that depth
    # (``resnet50``, ``densenet121``, ...); the assert guards the lookup.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        kwargs = dict(common, shortcut_type=opt.resnet_shortcut)
        # Only the depth-50-and-above constructors are given model_type.
        if opt.model_depth in (50, 101, 152, 200):
            kwargs['model_type'] = opt.model_type
        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = build(**kwargs)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        build = getattr(wide_resnet, 'resnet%d' % opt.model_depth)
        model = build(shortcut_type=opt.resnet_shortcut,
                      k=opt.wide_resnet_k,
                      **common)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnet%d' % opt.model_depth)
        model = build(shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      **common)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        build = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)
        model = build(shortcut_type=opt.resnet_shortcut, **common)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        build = getattr(densenet, 'densenet%d' % opt.model_depth)
        model = build(**common)

    if not opt.no_cuda:
        model = model.cuda(device=opt.cuda_id)
        # NOTE(review): the model is moved to opt.cuda_id but DataParallel is
        # pinned to device 0; the two agree only when cuda_id refers to
        # device 0 -- confirm this is intended.
        model = nn.DataParallel(model, device_ids=[0])  # CUDA change

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            print(pretrain['arch'])
            arch = f'{opt.model}-{opt.model_depth}'
            print(arch)
            assert arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            net = model.module
            if opt.model == 'densenet':
                net.classifier = nn.Linear(
                    net.classifier.in_features,
                    opt.n_finetune_classes).cuda(device=opt.cuda_id)
            else:
                # Replacement head: a small MLP instead of one linear layer.
                net.fc = nn.Sequential(
                    nn.Dropout(0.4),
                    nn.Linear(net.fc.in_features, 512),
                    nn.ReLU6(),
                    nn.Dropout(0.4),
                    nn.Linear(512, 128),
                    nn.ReLU6(),
                    nn.Linear(128, opt.n_finetune_classes),
                ).cuda(device=opt.cuda_id)

            # Materialise once: counting via list(...) would otherwise
            # consume a generator result and hand the caller an exhausted
            # iterator.
            parameters = list(
                get_fine_tuning_parameters(model, opt.ft_begin_index))
            print(len(parameters), 'params to fine tune')
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # Remap storages to the CPU so that a checkpoint written on a GPU
            # machine can still be read on a host without CUDA.
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)

            return model, parameters

    return model, model.parameters()
コード例 #23
0
def generate_model(opt):
    """Build the network described by ``opt`` and optionally prepare it for fine-tuning.

    ``opt.model`` selects the family (``resnet``, ``preresnet``,
    ``wideresnet``, ``resnext`` or ``densenet``) and ``opt.model_depth`` the
    variant.  Unless ``opt.no_cuda`` is set, the model is moved to the GPU and
    wrapped in ``nn.DataParallel``.  When ``opt.pretrain_path`` is set, the
    checkpoint is loaded, its ``'arch'`` entry is compared with ``opt.arch``
    and the last layer (``classifier`` for densenet, ``fc`` otherwise) is
    replaced by a new ``nn.Linear`` with ``opt.n_finetune_classes`` outputs.

    Returns:
        A ``(model, parameters)`` tuple.  ``parameters`` comes from the
        family's ``get_fine_tuning_parameters(model, opt.ft_begin_index)``
        when a checkpoint was loaded, otherwise it is ``model.parameters()``.

    Raises:
        AssertionError: if ``opt.model`` or ``opt.model_depth`` is not a
            supported value, or if the checkpoint's ``'arch'`` differs from
            ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]
    depth = opt.model_depth

    # Each branch picks the module that owns the constructors, the name stem
    # of those constructors (they are called <stem><depth>) and the keyword
    # arguments that only this family takes.
    if opt.model == 'resnet':
        assert depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        family, stem = resnet, 'resnet'
        extra = {'shortcut_type': opt.resnet_shortcut}
    elif opt.model == 'wideresnet':
        assert depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        family, stem = wide_resnet, 'resnet'
        extra = {
            'shortcut_type': opt.resnet_shortcut,
            'k': opt.wide_resnet_k,
        }
    elif opt.model == 'resnext':
        assert depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        family, stem = resnext, 'resnet'
        extra = {
            'shortcut_type': opt.resnet_shortcut,
            'cardinality': opt.resnext_cardinality,
        }
    elif opt.model == 'preresnet':
        assert depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        family, stem = pre_act_resnet, 'resnet'
        extra = {'shortcut_type': opt.resnet_shortcut}
    elif opt.model == 'densenet':
        assert depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        family, stem = densenet, 'densenet'
        extra = {}

    # Only the constructor that is actually needed is looked up.
    model = getattr(family, '%s%d' % (stem, depth))(
        num_classes=opt.n_classes,
        sample_size=opt.sample_size,
        sample_duration=opt.sample_duration,
        **extra)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            # The model is wrapped, so the real network lives in .module.
            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes).cuda()
            else:
                model.module.fc = nn.Linear(
                    model.module.fc.in_features,
                    opt.n_finetune_classes).cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # Keep every storage on the CPU so that a checkpoint written on a
            # GPU machine can still be read when CUDA is not in use.
            pretrain = torch.load(opt.pretrain_path,
                                  map_location=lambda storage, loc: storage)
            assert opt.arch == pretrain['arch']

            # Checkpoints saved from an nn.DataParallel model carry a
            # 'module.' prefix on every key.  Strip it only where it is
            # present; blindly dropping seven characters would corrupt the
            # keys of a checkpoint saved from an unwrapped model.
            prefix = 'module.'
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in pretrain['state_dict'].items()
            }

            model.load_state_dict(state_dict)

            if opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
コード例 #24
0
def get_model(config):
    """Create the network named by ``config.model`` on a CUDA device.

    The model is built from ``config.num_classes`` (plus family-specific
    settings), moved to the GPU and, for everything except ``i3d``, wrapped in
    ``nn.DataParallel``.  If ``config.checkpoint_path`` is set, the weights
    are loaded from it and the output layer is replaced so that it produces
    ``config.finetune_num_classes`` outputs.

    Returns:
        ``(model, parameters)``.  ``parameters`` is ``model.parameters()``
        when no checkpoint is given, otherwise the result of the family's
        ``get_fine_tuning_parameters``.

    Raises:
        ValueError: if ``'cuda'`` does not occur in ``config.device``.
        AssertionError: on an unsupported model or depth, a checkpoint path
            that is not a file, or fine-tuning a model other than ``i3d`` or
            ``resnet``.
    """
    assert config.model in [
        'i3d', 'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]
    print('Initializing {} model (num_classes={})...'.format(
        config.model, config.num_classes))

    arch = config.model
    builder = None  # constructor of the 3D CNN families; stays None for i3d
    extra = {}      # keyword arguments specific to the selected family

    if arch == 'i3d':
        from models.i3d import get_fine_tuning_parameters

        model = InceptionI3D(num_classes=config.num_classes,
                             spatial_squeeze=True,
                             final_endpoint='logits',
                             in_channels=3,
                             dropout_keep_prob=config.dropout_keep_prob)
    elif arch == 'resnet':
        assert config.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters

        builder = getattr(resnet, 'resnet%d' % config.model_depth)
        extra = {'shortcut_type': config.resnet_shortcut}
    elif arch == 'wideresnet':
        assert config.model_depth in [50]
        from models.wide_resnet import get_fine_tuning_parameters

        builder = getattr(wide_resnet, 'resnet%d' % config.model_depth)
        extra = {
            'shortcut_type': config.resnet_shortcut,
            'k': config.wide_resnet_k,
        }
    elif arch == 'resnext':
        assert config.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters

        builder = getattr(resnext, 'resnet%d' % config.model_depth)
        extra = {
            'shortcut_type': config.resnet_shortcut,
            'cardinality': config.resnext_cardinality,
        }
    elif arch == 'densenet':
        assert config.model_depth in [121, 169, 201, 264]
        from models.densenet import get_fine_tuning_parameters

        builder = getattr(densenet, 'densenet%d' % config.model_depth)
    # NOTE(review): 'preresnet' passes the assertion above but has no branch
    # here, so no model gets built for it - confirm whether that is intended.

    if builder is not None:
        model = builder(num_classes=config.num_classes,
                        spatial_size=config.spatial_size,
                        sample_duration=config.sample_duration,
                        **extra)

    if 'cuda' not in config.device:
        raise ValueError('CPU training not supported.')

    print('Moving model to CUDA device...')
    model = model.cuda()
    if arch != 'i3d':
        model = nn.DataParallel(model, device_ids=None)

    # Nothing to fine-tune: hand back every parameter.
    if not config.checkpoint_path:
        return model, model.parameters()

    print('Loading pretrained model {}'.format(config.checkpoint_path))
    assert os.path.isfile(config.checkpoint_path)

    checkpoint = torch.load(config.checkpoint_path)
    # For i3d the loaded object is used directly as the state dict; the other
    # families read it from the 'state_dict' entry.
    if arch == 'i3d':
        pretrained_weights = checkpoint
    else:
        pretrained_weights = checkpoint['state_dict']
    model.load_state_dict(pretrained_weights)

    print('Replacing model logits with {} output classes.'.format(
        config.finetune_num_classes))

    if arch == 'i3d':
        model.replace_logits(config.finetune_num_classes)
    elif arch == 'densenet':
        # DataParallel keeps the wrapped network under .module.
        head = nn.Linear(model.module.classifier.in_features,
                         config.finetune_num_classes)
        model.module.classifier = head.cuda()
    else:
        head = nn.Linear(model.module.fc.in_features,
                         config.finetune_num_classes)
        model.module.fc = head.cuda()

    # Decide which layers get trained.
    assert arch in ('i3d', 'resnet'), 'finetune params not implemented...'
    if arch in ('i3d', 'resnet'):
        finetune_criterion = config.finetune_prefixes
    else:
        finetune_criterion = config.finetune_begin_index

    return model, get_fine_tuning_parameters(model, finetune_criterion)
コード例 #25
0
def get_model(args):
    """Build the classifier selected by ``args.model``.

    Supported values are ``derpnet``, ``alexnet``, ``resnet``, ``vgg``,
    ``vgg_attn`` and ``inception``; ``vgg``, ``vgg_attn`` and ``resnet``
    additionally read ``args.model_depth``.

    When ``args.pretrained`` and ``args.pretrain_path`` are both set and the
    model is not ``alexnet``, ``vgg`` or ``resnet``, a checkpoint is loaded
    from ``args.pretrain_path``: every entry except those of the ``fc`` layer
    is passed to ``load_my_state_dict`` and ``model.fc`` is replaced by a new
    ``nn.Linear`` with ``args.n_classes`` outputs.

    Returns:
        ``(model, model.parameters())`` - all parameters are trained.

    Raises:
        AssertionError: on an unsupported model or depth, or when the
            checkpoint's ``'arch'`` differs from ``args.arch``.
    """
    assert args.model in [
        'derpnet', 'alexnet', 'resnet', 'vgg', 'vgg_attn', 'inception'
    ]

    if args.model == 'alexnet':
        model = alexnet.alexnet(pretrained=args.pretrained,
                                n_channels=args.n_channels,
                                num_classes=args.n_classes)
    elif args.model == 'inception':
        model = inception.inception_v3(pretrained=args.pretrained,
                                       aux_logits=False,
                                       progress=True,
                                       num_classes=args.n_classes)
    elif args.model == 'vgg':
        assert args.model_depth in [11, 13, 16, 19]

        # Depths 11/13/16 use the batch-norm constructors, 19 the plain one.
        vgg_builders = {
            11: 'vgg11_bn',
            13: 'vgg13_bn',
            16: 'vgg16_bn',
            19: 'vgg19',
        }
        model = getattr(vgg, vgg_builders[args.model_depth])(
            pretrained=args.pretrained,
            progress=True,
            num_classes=args.n_classes)
    elif args.model == 'vgg_attn':
        assert args.model_depth in [11, 13, 16, 19]

        # NOTE(review): every accepted depth has always built vgg11_bn here.
        # That looks like a copy-paste slip, but the other constructors of
        # vgg_attn are not visible from this file - confirm before changing.
        model = vgg_attn.vgg11_bn(pretrained=args.pretrained,
                                  progress=True,
                                  num_classes=args.n_classes)
    elif args.model == 'derpnet':
        model = derp_net.Net(n_channels=args.n_channels,
                             num_classes=args.n_classes)
    elif args.model == 'resnet':
        assert args.model_depth in [10, 18, 34, 50, 101, 152, 200]

        # Constructors are named resnet<depth>.
        model = getattr(resnet, 'resnet%d' % args.model_depth)(
            pretrained=args.pretrained,
            num_classes=args.n_classes)

    if (args.pretrained and args.pretrain_path
            and args.model not in ('alexnet', 'vgg', 'resnet')):

        print('loading pretrained model {}'.format(args.pretrain_path))
        pretrain = torch.load(args.pretrain_path)
        assert args.arch == pretrain['arch']

        from collections import OrderedDict
        import types

        # Keep every checkpoint entry except the final fc layer, which is
        # re-created below.  Keys saved from an nn.DataParallel model start
        # with 'module.'; remove that prefix only where it is present so keys
        # from an unwrapped model are not truncated.
        prefix = 'module.'
        pretrain_dict = OrderedDict()
        for key, value in pretrain['state_dict'].items():
            name = key[len(prefix):] if key.startswith(prefix) else key
            if not name.startswith('fc'):
                pretrain_dict[name] = value

        # Bind the project's loader onto this instance, see
        # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance
        model.load_state_dict = types.MethodType(load_my_state_dict, model)
        model.load_state_dict(pretrain_dict)

        model.fc = nn.Linear(model.fc.in_features, args.n_classes)

        # The whole network is fine-tuned, not just the new fc layer.
        return model, model.parameters()

    return model, model.parameters()
コード例 #26
0
def generate_model(opt):
    """Build the network described by ``opt`` and optionally prepare it for fine-tuning.

    ``opt.model`` selects one of ``resnet``, ``preresnet``, ``wideresnet``,
    ``resnext``, ``resnext_fa``, ``densenet`` or ``p3d``; ``opt.model_depth``
    selects the variant.  Unless ``opt.no_cuda`` is set the model is moved to
    the GPU and wrapped in ``nn.DataParallel``.

    With ``opt.pretrain_path`` set, the checkpoint is loaded and the output
    layer is rebuilt for ``opt.n_finetune_classes``.  The two device paths
    differ: the CUDA path merges the checkpoint into the model's current
    state dict, does not compare ``opt.arch`` with the checkpoint, and only
    replaces ``fc`` when ``opt.n_classes != opt.n_finetune_classes``; the CPU
    path checks ``opt.arch`` and always replaces the layer.

    Returns:
        ``(model, parameters)``; ``parameters`` is ``model.parameters()``
        when no checkpoint path is given.

    Raises:
        AssertionError: on an unsupported model or depth, or (CPU path) when
            the checkpoint's ``'arch'`` differs from ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'resnext_fa', 'densenet', 'p3d'
    ]

    arch = opt.model
    depth = opt.model_depth
    family = None  # module holding <stem><depth> constructors; None for p3d
    extra = {}     # keyword arguments specific to the selected family

    if arch == 'resnet':
        assert depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        family, stem = resnet, 'resnet'
        extra = {'shortcut_type': opt.resnet_shortcut}
    elif arch == 'wideresnet':
        assert depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        family, stem = wide_resnet, 'resnet'
        extra = {
            'shortcut_type': opt.resnet_shortcut,
            'k': opt.wide_resnet_k,
        }
    elif arch == 'p3d':
        assert depth in [50, 101, 152]

        # NOTE(review): no get_fine_tuning_parameters is bound on this
        # branch, so the fine-tuning call further down cannot succeed for
        # p3d when opt.pretrain_path is set - confirm the intended behavior.
        p3d_builders = {50: 'P3D63', 101: 'P3D131', 152: 'P3D199'}
        model = getattr(p3d, p3d_builders[depth])(num_classes=opt.n_classes)
    elif arch == 'resnext':
        assert depth in [50, 101, 152]

        # NOTE(review): the helper is taken from models.resnext_fa although
        # the model itself comes from resnext - verify this is deliberate.
        from models.resnext_fa import get_fine_tuning_parameters

        family, stem = resnext, 'resnet'
        extra = {
            'shortcut_type': opt.resnet_shortcut,
            'cardinality': opt.resnext_cardinality,
        }
    elif arch == 'resnext_fa':
        assert depth in [50, 101, 152]

        from models.resnext_fa import get_fine_tuning_parameters, get_fine_tuning_parameters_fa

        family, stem = resnext_fa, 'resnet'
        extra = {
            'shortcut_type': opt.resnet_shortcut,
            'cardinality': opt.resnext_cardinality,
        }
    elif arch == 'preresnet':
        assert depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        family, stem = pre_act_resnet, 'resnet'
        extra = {'shortcut_type': opt.resnet_shortcut}
    elif arch == 'densenet':
        assert depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        family, stem = densenet, 'densenet'

    if family is not None:
        model = getattr(family, '%s%d' % (stem, depth))(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration,
            **extra)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = nn.DataParallel(model.cuda(), device_ids=None)

    # Without a checkpoint every parameter is returned for training.
    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    pretrain = torch.load(opt.pretrain_path)

    if use_cuda:
        # Start from the model's own weights and overwrite them with whatever
        # the checkpoint provides before loading the combined dict.
        merged_state = model.state_dict()
        merged_state.update(pretrain['state_dict'])
        model.load_state_dict(merged_state)

        if arch == 'densenet':
            model.module.classifier = nn.Linear(
                model.module.classifier.in_features,
                opt.n_finetune_classes).cuda()
        elif opt.n_classes != opt.n_finetune_classes:
            # The existing fc layer is kept when the class count is unchanged.
            model.module.fc = nn.Linear(
                model.module.fc.in_features,
                opt.n_finetune_classes).cuda()

        if arch == 'resnext_fa':
            parameters = get_fine_tuning_parameters_fa(model, opt.learning_rate)
        else:
            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        return model, parameters

    assert opt.arch == pretrain['arch']

    model.load_state_dict(pretrain['state_dict'])

    if arch == 'densenet':
        model.classifier = nn.Linear(model.classifier.in_features,
                                     opt.n_finetune_classes)
    else:
        model.fc = nn.Linear(model.fc.in_features, opt.n_finetune_classes)

    return model, get_fine_tuning_parameters(model, opt.ft_begin_index)
コード例 #27
0
ファイル: model.py プロジェクト: tomzhaojinxin/video_cluster
def generate_model(opt):
    """Build the 3D ResNet selected by ``opt`` and optionally load a checkpoint.

    Args:
        opt: Options object. Always reads ``model``, ``model_depth``,
            ``n_classes``, ``resnet_shortcut``, ``sample_size``,
            ``sample_duration``, ``no_cuda`` and ``pretrain_path``. When
            ``pretrain_path`` is set it also reads ``n_finetune_classes`` and
            ``ft_begin_index``, plus ``arch`` whenever the checkpoint is
            architecture-checked.

    Returns:
        A ``(model, parameters)`` tuple. When a checkpoint is loaded,
        ``parameters`` is the result of ``get_fine_tuning_parameters``;
        otherwise it is ``model.parameters()``. Unless ``opt.no_cuda`` is
        set, the model is moved to CUDA and wrapped in ``nn.DataParallel``.

    Raises:
        AssertionError: If ``opt.model`` / ``opt.model_depth`` is not
            supported, or if ``opt.arch`` differs from the checkpoint's
            ``'arch'`` entry.
    """
    assert opt.model in ['resnet']

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        input_chan = 3

        from models.resnet import get_fine_tuning_parameters

        build_kwargs = dict(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
        # NOTE(review): only the 10/18/34/50 variants were ever given an
        # explicit input_chan; the deeper ones rely on the constructor's
        # default. Kept as-is -- confirm whether that asymmetry is intended.
        if opt.model_depth in (10, 18, 34, 50):
            build_kwargs['input_chan'] = input_chan

        # Constructors are named resnet<depth> in the resnet module.
        model = getattr(resnet, 'resnet%d' % opt.model_depth)(**build_kwargs)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            if 'arch' in pretrain:
                # Checkpoint with an 'arch' tag: keys are loaded into the
                # DataParallel wrapper itself.
                assert opt.arch == pretrain['arch']
                model.load_state_dict(pretrain['state_dict'])
            elif 'state_dict' in pretrain:
                # No 'arch' tag: weights are loaded into the wrapped module.
                model.module.load_state_dict(pretrain['state_dict'])
            else:
                # 'model_state_dict' checkpoints are loaded against a
                # 128-output fc head, so install one before loading.
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            128)
                model.load_state_dict(pretrain['model_state_dict'])

            # Fresh classification head for the fine-tuning task.
            model.module.fc = nn.Linear(model.module.fc.in_features,
                                        opt.n_finetune_classes)
            model.module.fc = model.module.fc.cuda()

            model.module.freeze_layers(opt.ft_begin_index)
            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # Keep every tensor on the CPU: without map_location a checkpoint
            # saved from a GPU is restored onto that GPU, which fails on a
            # CPU-only host and defeats --no_cuda elsewhere.
            pretrain = torch.load(opt.pretrain_path,
                                  map_location=lambda storage, loc: storage)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            model.fc = nn.Linear(model.fc.in_features, opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
コード例 #28
0
def generate_model(opt):
    """Build the video model selected by ``opt`` and optionally load weights.

    Supported ``opt.model`` values are ``resnet``, ``preresnet``,
    ``wideresnet``, ``resnext``, ``densenet``, ``i3d`` and ``i3dv2``.

    Args:
        opt: Options object. Reads ``model``, ``n_classes``, ``no_cuda`` and
            ``pretrain_path``; depending on the family also ``model_depth``,
            ``resnet_shortcut``, ``sample_size``, ``sample_duration``,
            ``wide_resnet_k`` and ``resnext_cardinality``. When a checkpoint
            is loaded it also reads ``n_finetune_classes``,
            ``ft_begin_index`` and (for tagged checkpoints) ``arch``.

    Returns:
        A ``(model, parameters)`` tuple. When a checkpoint is loaded,
        ``parameters`` is the result of the family's
        ``get_fine_tuning_parameters``; otherwise it is
        ``model.parameters()``. Unless ``opt.no_cuda`` is set, the model is
        moved to CUDA and wrapped in ``nn.DataParallel``.

    Raises:
        AssertionError: If ``opt.model`` / ``opt.model_depth`` is not
            supported, or if ``opt.arch`` differs from the checkpoint's
            ``'arch'`` entry.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet', 'i3d',
        'i3dv2'
    ]

    # Within each family the constructors are named <prefix><depth>, so the
    # depth is resolved by name once it has been validated.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        model = getattr(resnet, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = getattr(wide_resnet, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            k=opt.wide_resnet_k,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        model = getattr(resnext, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        model = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        model = getattr(densenet, 'densenet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == "i3d":

        from models.i3dpt import get_fine_tuning_parameters

        model = i3dpt.I3D(num_classes=opt.n_classes, dropout_prob=0.5)

    elif opt.model == "i3dv2":

        from models.I3D_Pytorch import get_fine_tuning_parameters

        model = I3D_Pytorch.I3D(num_classes=opt.n_classes,
                                dropout_keep_prob=0.5)

    # The CUDA branch treats checkpoints of these models as a flat
    # {parameter_name: tensor} mapping rather than {'arch', 'state_dict'}.
    flat_checkpoint_models = ('i3d', 'i3dv2')

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)

            if opt.model not in flat_checkpoint_models:
                assert opt.arch == pretrain['arch']
                model.load_state_dict(pretrain['state_dict'])
            else:
                # DataParallel prefixes every key with "module."; add the
                # prefix so the flat checkpoint lines up, and merge it over
                # the current weights so keys it lacks keep their init.
                pretrain = {"module." + k: v for k, v in pretrain.items()}
                model_dict = model.state_dict()
                model_dict.update(pretrain)
                model.load_state_dict(model_dict)

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # Keep every tensor on the CPU: without map_location a checkpoint
            # saved from a GPU is restored onto that GPU, which fails on a
            # CPU-only host.
            pretrain = torch.load(opt.pretrain_path,
                                  map_location=lambda storage, loc: storage)

            if (opt.model in flat_checkpoint_models
                    and 'state_dict' not in pretrain):
                # Mirror the CUDA branch for flat i3d checkpoints (no
                # "module." prefix here because the model is not wrapped).
                # Previously this path failed on pretrain['arch'].
                model_dict = model.state_dict()
                model_dict.update(pretrain)
                model.load_state_dict(model_dict)
            else:
                assert opt.arch == pretrain['arch']
                model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
コード例 #29
0
def generate_model(opt):
    """Build the video model selected by ``opt`` and optionally load weights.

    Supported ``opt.model`` values are ``resnet``, ``preresnet``,
    ``wideresnet``, ``resnext``, ``densenet``, ``mobilenet`` and
    ``mobilenetv2``.

    Args:
        opt: Options object. Reads ``model``, ``n_classes``, ``sample_size``,
            ``no_cuda`` and ``pretrain_path``; depending on the family also
            ``model_depth``, ``resnet_shortcut``, ``sample_duration``,
            ``wide_resnet_k``, ``resnext_cardinality`` and ``width_mult``.
            On the CUDA path ``no_cuda_predict`` is read as well. When a
            checkpoint is loaded it also reads ``arch``,
            ``n_finetune_classes`` and ``ft_begin_index``.

    Returns:
        A ``(model, parameters)`` tuple. When a checkpoint is loaded,
        ``parameters`` is the result of the family's
        ``get_fine_tuning_parameters``; otherwise it is
        ``model.parameters()``. Unless ``opt.no_cuda`` is set, the model is
        wrapped in ``nn.DataParallel``.

    Raises:
        AssertionError: If ``opt.model`` / ``opt.model_depth`` is not
            supported, or if ``opt.arch`` differs from the checkpoint's
            ``'arch'`` entry.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet',
        'mobilenet', 'mobilenetv2'
    ]

    # Within each family the constructors are named <prefix><depth>, so the
    # depth is resolved by name once it has been validated.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        model = getattr(resnet, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = getattr(wide_resnet, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            k=opt.wide_resnet_k,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        model = getattr(resnext, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        model = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        model = getattr(densenet, 'densenet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)

    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)

    # Families whose head is replaced with a (Dropout, Linear) classifier
    # and whose fine-tuning start is given as a string mode.
    sequential_head_models = (
        'mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2'
    )

    if not opt.no_cuda:
        if not opt.no_cuda_predict:
            model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            print("Pretrain arch", pretrain['arch'])
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])
            ft_begin_index = opt.ft_begin_index
            if opt.model in sequential_head_models:
                model.module.classifier = nn.Sequential(
                    nn.Dropout(0.9),
                    nn.Linear(model.module.classifier[1].in_features,
                              opt.n_finetune_classes))
                model.module.classifier = model.module.classifier.cuda()
                # These families take a mode string instead of an index.
                ft_begin_index = 'complete' if ft_begin_index == 0 else 'last_layer'
            elif opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()
            print("Finetuning at:", ft_begin_index)
            parameters = get_fine_tuning_parameters(model, ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # Keep every tensor on the CPU: without map_location a checkpoint
            # saved from a GPU is restored onto that GPU, which fails on a
            # CPU-only host.
            pretrain = torch.load(opt.pretrain_path,
                                  map_location=lambda storage, loc: storage)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])
            ft_begin_index = opt.ft_begin_index
            if opt.model in sequential_head_models:
                # The model is not wrapped in DataParallel on this path, so
                # the head lives directly on `model` (there is no `.module`)
                # and it must stay on the CPU.
                model.classifier = nn.Sequential(
                    nn.Dropout(0.9),
                    nn.Linear(model.classifier[1].in_features,
                              opt.n_finetune_classes))
                ft_begin_index = 'complete' if ft_begin_index == 0 else 'last_layer'
            elif opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)
            print("Finetuning at:", ft_begin_index)
            parameters = get_fine_tuning_parameters(model, ft_begin_index)
            return model, parameters

    return model, model.parameters()
コード例 #30
0
ファイル: model.py プロジェクト: generation21/generation6011
def generate_model(opt):
    """Build the transfer-learning ResNet/ResNeXt selected by ``opt``.

    Args:
        opt: Options object. Reads ``model``, ``model_depth``, ``n_classes``,
            ``resnet_shortcut``, ``sample_size``, ``sample_duration``,
            ``isSource``, ``transfer_module``, ``sourceKind``, ``layer_num``,
            ``multi_source``, ``no_cuda`` and ``pretrain_path`` (plus
            ``resnext_cardinality`` for ``resnext``). When a checkpoint is
            loaded it also reads ``arch`` and ``ft_begin_index``, and on the
            CUDA path ``inference`` and ``n_finetune_classes``.

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` is the result of
        ``get_fine_tuning_parameters`` when a checkpoint is loaded for
        fine-tuning, and ``model.parameters()`` otherwise. When no
        ``pretrain_path`` is given the model is returned as constructed.

    Raises:
        AssertionError: If ``opt.model`` / ``opt.model_depth`` is not
            supported, or if ``opt.arch`` differs from the checkpoint's
            ``'arch'`` entry.
    """
    assert opt.model in [
        'resnet', 'resnext'
    ]

    # Both families name their constructors resnet<depth>, so the depth is
    # resolved by name once it has been validated.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        model = getattr(resnet, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration,
            isSource=opt.isSource,
            transfer_module=opt.transfer_module,
            sourceKind=opt.sourceKind,
            layer_num=opt.layer_num,
            multi_source=opt.multi_source)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        model = getattr(resnext, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration,
            isSource=opt.isSource,
            transfer_module=opt.transfer_module,
            sourceKind=opt.sourceKind,
            layer_num=opt.layer_num,
            multi_source=opt.multi_source)
    print(opt.no_cuda)
    print(type(opt.no_cuda))
    if not opt.no_cuda:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            print('loading pretrained model arch', pretrain['arch'], opt.arch)
            assert opt.arch == pretrain['arch']

            pretrained_dict = pretrain['state_dict']
            model_dict = model.state_dict()
            # The model is not wrapped yet, so drop the "module." that
            # DataParallel put into the checkpoint keys, then merge over the
            # current weights so keys the checkpoint lacks keep their init.
            pretrained_dict = {
                k.replace('module.', ''): v
                for k, v in pretrained_dict.items()
            }
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=None)
            # `== False` is kept on purpose: values that are falsy but not
            # equal to False (e.g. None) take neither branch and fall through
            # to the final return.
            if opt.inference == False:

                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

                parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)

                print(model)
                return model, parameters
            elif opt.inference:
                # NOTE(review): the model is already on CUDA and wrapped
                # above, so this wraps it in DataParallel a second time.
                # Kept because callers may depend on the resulting
                # structure -- confirm and drop if not.
                model = model.cuda()
                model = nn.DataParallel(model, device_ids=None)
                return model, model.parameters()
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # Load before reporting the arch: the checkpoint used to be
            # printed before it was read, which raised UnboundLocalError and
            # made this whole branch unusable.
            pretrain = torch.load(opt.pretrain_path)
            print('loading pretrained model arch', pretrain['arch'])

            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])


            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            # NOTE(review): this is the no_cuda branch, yet the model is
            # still moved to CUDA and wrapped. Left unchanged because the
            # intent is unclear from here -- confirm.
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=None)
            return model, parameters

    return model, model.parameters()
コード例 #31
0
ファイル: model.py プロジェクト: jarrelscy/3D-ResNets-PyTorch
def _strip_data_parallel_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys.

    ``nn.DataParallel`` registers the wrapped network under the attribute
    ``module``, so checkpoints saved from a wrapped model have every key
    prefixed with ``module.``. Keys without that prefix are kept unchanged.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def generate_model(opt):
    """Build the network selected by ``opt`` and optionally load pretrained weights.

    Attributes read from ``opt`` (presumably an argparse namespace -- confirm
    against the caller):

    * ``model``: one of 'resnet', 'preresnet', 'wideresnet', 'resnext',
      'densenet'.
    * ``model_depth``: depth, checked against the depths supported by the
      chosen family.
    * ``n_classes``, ``sample_size``, ``sample_duration``: forwarded to every
      constructor.
    * ``resnet_shortcut``: forwarded as ``shortcut_type`` (all but densenet).
    * ``wide_resnet_k``: forwarded as ``k`` (wideresnet only).
    * ``resnext_cardinality``: forwarded as ``cardinality`` (resnext only).
    * ``no_cuda``: when false the model is moved to the GPU and wrapped in
      ``nn.DataParallel``.
    * ``pretrain_path``: when truthy, a checkpoint is loaded from this path.
    * ``arch``: must equal the ``'arch'`` entry stored in the checkpoint.
    * ``n_finetune_classes``: output size of the replacement final layer.
    * ``ft_begin_index``: forwarded to ``get_fine_tuning_parameters``.

    Returns:
        A ``(model, parameters)`` tuple. Without ``pretrain_path``,
        ``parameters`` is ``model.parameters()``. With it, the final linear
        layer (``classifier`` for densenet, ``fc`` otherwise) is replaced by a
        fresh ``nn.Linear`` and ``parameters`` is whatever
        ``get_fine_tuning_parameters(model, opt.ft_begin_index)`` returns.

    Raises:
        AssertionError: on an unsupported ``opt.model`` / ``opt.model_depth``
            or when ``opt.arch`` does not match the checkpoint.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    # Each branch picks the module holding the constructors, the constructor
    # name prefix, and the keyword arguments specific to that family. The
    # function-scope import selects the matching fine-tuning helper.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        arch_module, ctor_prefix = resnet, 'resnet'
        extra_kwargs = {'shortcut_type': opt.resnet_shortcut}
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        arch_module, ctor_prefix = wide_resnet, 'resnet'
        extra_kwargs = {
            'shortcut_type': opt.resnet_shortcut,
            'k': opt.wide_resnet_k,
        }
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        arch_module, ctor_prefix = resnext, 'resnet'
        extra_kwargs = {
            'shortcut_type': opt.resnet_shortcut,
            'cardinality': opt.resnext_cardinality,
        }
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        arch_module, ctor_prefix = pre_act_resnet, 'resnet'
        extra_kwargs = {'shortcut_type': opt.resnet_shortcut}
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        arch_module, ctor_prefix = densenet, 'densenet'
        extra_kwargs = {}

    # Constructors are named '<prefix><depth>' (e.g. resnet.resnet50,
    # densenet.densenet121), so one lookup replaces the per-depth chains.
    constructor = getattr(
        arch_module, '{}{:d}'.format(ctor_prefix, int(opt.model_depth)))
    model = constructor(
        num_classes=opt.n_classes,
        sample_size=opt.sample_size,
        sample_duration=opt.sample_duration,
        **extra_kwargs)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    if use_cuda:
        pretrain = torch.load(opt.pretrain_path)
    else:
        # Remap every storage to the CPU; otherwise a checkpoint saved from
        # GPU tensors cannot be deserialized on a machine without CUDA.
        pretrain = torch.load(
            opt.pretrain_path, map_location=lambda storage, loc: storage)
    assert opt.arch == pretrain['arch']

    state_dict = pretrain['state_dict']
    if not use_cuda:
        # The CPU model is not wrapped in DataParallel, so keys saved from a
        # wrapped model must lose their 'module.' prefix to match.
        state_dict = _strip_data_parallel_prefix(state_dict)
    model.load_state_dict(state_dict)

    # Replace the final linear layer with one sized for the fine-tuning task.
    # Under DataParallel the real network lives in ``model.module``.
    network = model.module if use_cuda else model
    head_name = 'classifier' if opt.model == 'densenet' else 'fc'
    new_head = nn.Linear(
        getattr(network, head_name).in_features, opt.n_finetune_classes)
    if use_cuda:
        new_head = new_head.cuda()
    setattr(network, head_name, new_head)

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters