Esempio n. 1
0
def generate_model(opt):
    """Build the network selected by ``opt`` plus its fine-tuning parameters.

    Args:
        opt: Options object. Reads ``model``, ``model_depth``, ``n_classes``,
            ``no_cuda`` and ``ft_begin_index``.

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` is whatever the
        selected module's ``get_fine_tuning_parameters`` returns for
        ``(model, opt.ft_begin_index)``. When ``opt.no_cuda`` is false the
        model is moved to CUDA and wrapped in ``nn.DataParallel`` first.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is not in the
            accepted lists.
        NotImplementedError: If the model/depth combination passes the checks
            but no constructor exists for it (e.g. ``'densenet'``).
    """
    assert opt.model in ['resnet', 'densenet', 'se_resnet']

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        family = resnet
    elif opt.model == 'se_resnet':
        assert opt.model_depth in [18, 34, 50, 101, 152]

        from models.se_resnet import get_fine_tuning_parameters

        family = se_resnet
    else:
        # 'densenet' is accepted by the first check but nothing here can
        # construct it; fail with a clear message instead of an unbound name.
        raise NotImplementedError(
            'no constructor available for model {!r}'.format(opt.model))

    # Both families name their constructors ``resnet<depth>``, so look the
    # constructor up by name rather than spelling out one branch per depth.
    factory_name = 'resnet{}'.format(opt.model_depth)
    factory = getattr(family, factory_name, None)
    if factory is None:
        raise NotImplementedError(
            '{} does not provide {}'.format(opt.model, factory_name))
    model = factory(pretrained=True, num_classes=opt.n_classes)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
def generate_model(config):
    """Build a ``resnet.resnet101`` from the ``Network`` config section.

    Args:
        config: Parser-like object exposing ``get``, ``getint`` and
            ``getboolean``. Reads these ``Network`` options: ``classes``,
            ``resnet_shortcut``, ``resnet_cardinality``, ``sample_size``,
            ``sample_duration``, ``input_channels``, ``use_cuda``,
            ``pretrain_path`` and, when a pretrain path is set, ``arch``,
            ``n_finetune_classes`` and ``ft_begin_index``.

    Returns:
        ``(model, parameters)``. ``model`` is wrapped in ``nn.DataParallel``.
        Without a pretrain path ``parameters`` is ``model.parameters()``;
        otherwise it is the result of ``get_fine_tuning_parameters``.

    Raises:
        AssertionError: If the configured ``arch`` differs from the one
            stored in the checkpoint.
    """
    model = resnet.resnet101(
        num_classes=config.getint('Network', 'classes'),
        shortcut_type=config.get('Network', 'resnet_shortcut'),
        cardinality=config.getint('Network', 'resnet_cardinality'),
        sample_size=config.getint('Network', 'sample_size'),
        sample_duration=config.getint('Network', 'sample_duration'),
        input_channels=config.getint('Network', 'input_channels'),
    )
    use_cuda = config.getboolean('Network', 'use_cuda')
    if use_cuda:
        model = model.cuda()
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    model = nn.DataParallel(model)

    pretrain_path = config.get('Network', 'pretrain_path')
    if not pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(pretrain_path))
    # On a CPU-only run, remap the stored tensors to CPU so a checkpoint
    # written from a GPU process can still be deserialised.
    pretrain = torch.load(pretrain_path,
                          map_location=None if use_cuda else 'cpu')

    assert config.get('Network', 'arch') == pretrain['arch']
    model.load_state_dict(pretrain['state_dict'])

    head = nn.Linear(model.module.fc.in_features,
                     config.getint('Network', 'n_finetune_classes'))
    # Only move the new head to the GPU when CUDA is enabled; otherwise the
    # rest of the model lives on the CPU and ``.cuda()`` would fail.
    model.module.fc = head.cuda() if use_cuda else head

    parameters = get_fine_tuning_parameters(
        model, config.getint('Network', 'ft_begin_index'))
    return model, parameters
def get_model_param(args):
    """Build the requested ResNet and return it with its trainable parameters.

    Args:
        args: Options object. Reads ``model``, ``model_depth``,
            ``input_size``, ``n_classes`` and ``finetune``. When
            ``finetune`` is truthy it also reads ``arch``, ``root_path``,
            ``n_finetune_classes``, ``ft_begin_index``, ``lr_mult1`` and
            ``lr_mult2``, and sets ``args.pretrain_path`` as a side effect.

    Returns:
        ``(model, parameters)``. Without fine-tuning ``parameters`` is
        ``model.parameters()``; with fine-tuning it is the result of
        ``get_fine_tuning_parameters``.

    Raises:
        ValueError: If ``args.model`` is not ``'resnet'`` (the only family
            implemented here).
        AssertionError: If ``args.model_depth`` is not a supported depth.
    """
    if args.model != 'resnet':
        # Only the resnet family is wired up; fail clearly instead of
        # reaching the code below with ``model`` unbound.
        raise ValueError('unsupported model: {!r}'.format(args.model))

    assert args.model_depth in [18, 34, 50, 101, 152]

    from models.resnet import get_fine_tuning_parameters

    # Constructors are named ``resnet<depth>``; pick the one that matches.
    factory = getattr(resnet, 'resnet{}'.format(args.model_depth))
    model = factory(pretrained=False,
                    input_size=args.input_size,
                    num_classes=args.n_classes)

    if not args.finetune:
        return model, model.parameters()

    # Load pretrained weights, then swap in a new final layer.
    pretrained_model = model_path[args.arch]
    args.pretrain_path = os.path.join(args.root_path, 'pretrained_models',
                                      pretrained_model)
    print("=> loading pretrained model '{}'...".format(pretrained_model))

    model.load_state_dict(torch.load(args.pretrain_path))

    # Only the last layer is replaced for the new class count.
    model.fc = nn.Linear(model.fc.in_features, args.n_finetune_classes)

    parameters = get_fine_tuning_parameters(model, args.ft_begin_index,
                                            args.lr_mult1, args.lr_mult2)
    return model, parameters
Esempio n. 4
0
def generate_model(opt):
    """Build the network selected by ``opt`` and optionally load pretrained weights.

    Args:
        opt: Options object. Reads ``model``, ``model_depth``, ``n_classes``,
            ``sample_size``, ``sample_duration``, ``no_cuda`` and
            ``pretrain_path``, plus family-specific options
            (``resnet_shortcut``, ``model_type``, ``wide_resnet_k``,
            ``resnext_cardinality``). The CUDA path also reads ``cuda_id``;
            the pretrain path reads ``n_finetune_classes`` and
            ``ft_begin_index``; the CPU pretrain path reads ``arch``.

    Returns:
        ``(model, parameters)``. Without ``opt.pretrain_path`` ``parameters``
        is ``model.parameters()``. With it, ``parameters`` comes from the
        family's ``get_fine_tuning_parameters``; on the CUDA path it is
        materialised as a list so it can be counted and still returned
        intact.

    Raises:
        AssertionError: If the model name or depth is unsupported, or the
            checkpoint's ``arch`` does not match.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    def video_kwargs(**extra):
        """Return the keywords shared by the residual families plus ``extra``."""
        kwargs = dict(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
        kwargs.update(extra)
        return kwargs

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        # Only depths 50 and above are given ``model_type``.
        extra = {'model_type': opt.model_type} if opt.model_depth >= 50 else {}
        factory = getattr(resnet, f'resnet{opt.model_depth}')
        model = factory(**video_kwargs(**extra))
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = wide_resnet.resnet50(**video_kwargs(k=opt.wide_resnet_k))
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        factory = getattr(resnext, f'resnet{opt.model_depth}')
        model = factory(**video_kwargs(cardinality=opt.resnext_cardinality))
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        factory = getattr(pre_act_resnet, f'resnet{opt.model_depth}')
        model = factory(**video_kwargs())
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        # DenseNet constructors take no ``shortcut_type``.
        factory = getattr(densenet, f'densenet{opt.model_depth}')
        model = factory(num_classes=opt.n_classes,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)

    if not opt.no_cuda:
        model = model.cuda(device=opt.cuda_id)
        # NOTE(review): device_ids is pinned to [0] while the model is moved
        # to ``opt.cuda_id``; nn.DataParallel expects the module to live on
        # device_ids[0] -- confirm cuda_id is always 0 for callers.
        model = nn.DataParallel(model, device_ids=[0])  # CUDA change

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            print(pretrain['arch'])
            # This path derives the expected arch from model and depth,
            # whereas the CPU path below compares against ``opt.arch``.
            arch = f'{opt.model}-{opt.model_depth}'
            print(arch)
            assert arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda(
                    device=opt.cuda_id)
            else:
                # Replace the final layer with a small MLP head.
                model.module.fc = nn.Sequential(
                    nn.Dropout(0.4),
                    nn.Linear(model.module.fc.in_features, 512), nn.ReLU6(),
                    nn.Dropout(0.4), nn.Linear(512, 128), nn.ReLU6(),
                    nn.Linear(128,
                              opt.n_finetune_classes)).cuda(device=opt.cuda_id)

            # Materialise before counting: if the helper hands back a
            # generator, ``list()`` would otherwise consume it and the caller
            # would receive an empty iterator.
            parameters = list(
                get_fine_tuning_parameters(model, opt.ft_begin_index))
            print(len(parameters), 'params to fine tune')
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # Remap to CPU so checkpoints saved from a GPU process still load.
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)

            return model, parameters

    return model, model.parameters()
Esempio n. 5
0
def generate_model(opt):
    """Build the network selected by ``opt`` and optionally load pretrained weights.

    Args:
        opt: Options object. Reads ``model``, ``model_depth``, ``n_classes``,
            ``no_cuda`` and ``pretrain_path``, plus family-specific options
            (``resnet_shortcut``, ``sample_size``, ``sample_duration``,
            ``wide_resnet_k``, ``resnext_cardinality``). The pretrain path
            also reads ``n_finetune_classes`` and ``ft_begin_index``
            (``learning_rate`` for ``'resnext_fa'`` on CUDA, ``arch`` on CPU).

    Returns:
        ``(model, parameters)``. Without ``opt.pretrain_path`` ``parameters``
        is ``model.parameters()``. With it, ``parameters`` comes from the
        family's fine-tuning helper; ``'p3d'`` has no such helper imported
        here, so all of its parameters are returned.

    Raises:
        AssertionError: If the model name or depth is unsupported, or (CPU
            path only) the checkpoint's ``arch`` does not match ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'resnext_fa', 'densenet', 'p3d'
    ]

    def video_kwargs(**extra):
        """Return the keywords shared by the residual families plus ``extra``."""
        kwargs = dict(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
        kwargs.update(extra)
        return kwargs

    # Stays None for families that import no fine-tuning helper (p3d), so the
    # pretrain path can detect that instead of hitting an unbound local.
    get_fine_tuning_parameters = None

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        factory = getattr(resnet, 'resnet{}'.format(opt.model_depth))
        model = factory(**video_kwargs())
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = wide_resnet.resnet50(**video_kwargs(k=opt.wide_resnet_k))
    elif opt.model == 'p3d':
        assert opt.model_depth in [50, 101, 152]

        # Requested depth -> P3D constructor name.
        p3d_factories = {50: 'P3D63', 101: 'P3D131', 152: 'P3D199'}
        factory = getattr(p3d, p3d_factories[opt.model_depth])
        model = factory(num_classes=opt.n_classes)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        # NOTE(review): the helper is imported from ``resnext_fa`` although
        # the model is built from ``resnext`` -- confirm this is intended.
        from models.resnext_fa import get_fine_tuning_parameters

        factory = getattr(resnext, 'resnet{}'.format(opt.model_depth))
        model = factory(**video_kwargs(cardinality=opt.resnext_cardinality))
    elif opt.model == 'resnext_fa':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext_fa import get_fine_tuning_parameters, get_fine_tuning_parameters_fa

        factory = getattr(resnext_fa, 'resnet{}'.format(opt.model_depth))
        model = factory(**video_kwargs(cardinality=opt.resnext_cardinality))
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        factory = getattr(pre_act_resnet, 'resnet{}'.format(opt.model_depth))
        model = factory(**video_kwargs())
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        # DenseNet constructors take no ``shortcut_type``.
        factory = getattr(densenet, 'densenet{}'.format(opt.model_depth))
        model = factory(num_classes=opt.n_classes,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    if use_cuda:
        pretrain = torch.load(opt.pretrain_path)

        # Overlay the checkpoint on the current state so keys missing from
        # the checkpoint keep their freshly initialised values.
        model_dict = model.state_dict()
        model_dict.update(pretrain['state_dict'])
        model.load_state_dict(model_dict)

        if opt.model == 'densenet':
            model.module.classifier = nn.Linear(
                model.module.classifier.in_features, opt.n_finetune_classes)
            model.module.classifier = model.module.classifier.cuda()
        # No new fc layer is needed when the fine-tuning class count matches.
        elif opt.n_classes != opt.n_finetune_classes:
            model.module.fc = nn.Linear(model.module.fc.in_features,
                                        opt.n_finetune_classes)
            model.module.fc = model.module.fc.cuda()
    else:
        # Remap to CPU so checkpoints saved from a GPU process still load.
        pretrain = torch.load(opt.pretrain_path, map_location='cpu')
        assert opt.arch == pretrain['arch']

        model.load_state_dict(pretrain['state_dict'])

        if opt.model == 'densenet':
            model.classifier = nn.Linear(
                model.classifier.in_features, opt.n_finetune_classes)
        else:
            model.fc = nn.Linear(model.fc.in_features,
                                 opt.n_finetune_classes)

    if use_cuda and opt.model == 'resnext_fa':
        parameters = get_fine_tuning_parameters_fa(model, opt.learning_rate)
    elif get_fine_tuning_parameters is None:
        # No helper was imported for this family; say so and train everything
        # rather than failing after the checkpoint has been loaded.
        print('no fine-tuning parameter helper for model {}; '
              'returning all parameters'.format(opt.model))
        parameters = model.parameters()
    else:
        parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
Esempio n. 6
0
def generate_model(opt):
    """Build a ResNet for ``opt`` and optionally load pretrained weights.

    Args:
        opt: Options object. Reads ``model``, ``model_depth``, ``n_classes``,
            ``resnet_shortcut``, ``sample_size``, ``sample_duration``,
            ``no_cuda`` and ``pretrain_path``; the pretrain path also reads
            ``arch``, ``n_finetune_classes`` and ``ft_begin_index``.

    Returns:
        ``(model, parameters)``. Without ``opt.pretrain_path`` ``parameters``
        is ``model.parameters()``; otherwise it is the result of
        ``get_fine_tuning_parameters(model, opt.ft_begin_index)``.

    Raises:
        AssertionError: If the model name or depth is unsupported, or the
            checkpoint's ``arch`` does not match ``opt.arch``.
    """
    assert opt.model in ['resnet']

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        input_chan = 3

        from models.resnet import get_fine_tuning_parameters

        ctor_kwargs = {
            'num_classes': opt.n_classes,
            'shortcut_type': opt.resnet_shortcut,
            'sample_size': opt.sample_size,
            'sample_duration': opt.sample_duration,
        }
        # Only the shallower variants are given an explicit channel count.
        if opt.model_depth in (10, 18, 34, 50):
            ctor_kwargs['input_chan'] = input_chan
        build = getattr(resnet, 'resnet{}'.format(opt.model_depth))
        model = build(**ctor_kwargs)

    on_gpu = not opt.no_cuda
    if on_gpu:
        model = nn.DataParallel(model.cuda(), device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    checkpoint = torch.load(opt.pretrain_path)

    if on_gpu:
        wrapped = model.module
        if 'arch' in checkpoint:
            assert opt.arch == checkpoint['arch']
            model.load_state_dict(checkpoint['state_dict'])
        elif 'state_dict' in checkpoint:
            # No arch entry: weights are loaded into the wrapped module.
            wrapped.load_state_dict(checkpoint['state_dict'])
        else:
            # This checkpoint layout is loaded with a 128-unit fc in place.
            wrapped.fc = nn.Linear(wrapped.fc.in_features, 128)
            model.load_state_dict(checkpoint['model_state_dict'])

        new_head = nn.Linear(wrapped.fc.in_features, opt.n_finetune_classes)
        wrapped.fc = new_head.cuda()

        wrapped.freeze_layers(opt.ft_begin_index)
    else:
        assert opt.arch == checkpoint['arch']

        model.load_state_dict(checkpoint['state_dict'])

        model.fc = nn.Linear(model.fc.in_features, opt.n_finetune_classes)

    return model, get_fine_tuning_parameters(model, opt.ft_begin_index)
Esempio n. 7
0
def generate_model(opt):
    """Build the network selected by ``opt`` and pick the parameters to train.

    The architecture family comes from ``opt.model`` and the variant from
    ``opt.model_depth``. Every constructor receives ``num_classes``,
    ``sample_size`` and ``sample_duration``; all families except densenet
    also receive ``shortcut_type``, wideresnet additionally ``k`` and
    resnext additionally ``cardinality``.

    Unless ``opt.no_cuda`` is set, the model is moved to CUDA and wrapped in
    ``nn.DataParallel`` (``device_ids=opt.device_ids``). If
    ``opt.pretrain_path`` is truthy, the checkpoint is loaded, its ``'arch'``
    entry must equal ``opt.arch``, and the final linear layer (``classifier``
    for densenet, ``fc`` otherwise) is replaced by a fresh one with
    ``opt.n_finetune_classes`` outputs.

    Args:
        opt: Options object exposing the attributes named above, plus
            ``n_classes``, ``resnet_shortcut``, ``wide_resnet_k``,
            ``resnext_cardinality`` and ``ft_begin_index`` as needed by the
            selected branch.

    Returns:
        A ``(model, parameters)`` tuple. With a pretrained checkpoint,
        ``parameters`` is the result of
        ``get_fine_tuning_parameters(model, opt.ft_begin_index)`` from the
        selected family's module; otherwise it is ``model.parameters()``.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is not
            supported, or the checkpoint's ``'arch'`` differs from
            ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    # Per family: validate the depth, import the matching fine-tuning helper,
    # and remember which module / constructor prefix / extra kwargs to use.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        family, prefix = resnet, 'resnet'
        family_kwargs = {'shortcut_type': opt.resnet_shortcut}
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]
        from models.wide_resnet import get_fine_tuning_parameters
        family, prefix = wide_resnet, 'resnet'
        family_kwargs = {
            'shortcut_type': opt.resnet_shortcut,
            'k': opt.wide_resnet_k,
        }
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        family, prefix = resnext, 'resnet'
        family_kwargs = {
            'shortcut_type': opt.resnet_shortcut,
            'cardinality': opt.resnext_cardinality,
        }
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]
        from models.pre_act_resnet import get_fine_tuning_parameters
        family, prefix = pre_act_resnet, 'resnet'
        family_kwargs = {'shortcut_type': opt.resnet_shortcut}
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]
        from models.densenet import get_fine_tuning_parameters
        family, prefix = densenet, 'densenet'
        family_kwargs = {}

    # Constructors are named '<prefix><depth>' (e.g. resnet.resnet50,
    # densenet.densenet121), so one lookup replaces the per-depth branches.
    build_kwargs = {
        'num_classes': opt.n_classes,
        'sample_size': opt.sample_size,
        'sample_duration': opt.sample_duration,
    }
    build_kwargs.update(family_kwargs)
    constructor = getattr(family, '{}{}'.format(prefix, int(opt.model_depth)))
    model = constructor(**build_kwargs)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=opt.device_ids)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if use_cuda:
        pretrain = torch.load(opt.pretrain_path)
    else:
        # Remap every storage to CPU so that a checkpoint saved from CUDA
        # tensors can still be read on a machine without CUDA.
        pretrain = torch.load(opt.pretrain_path,
                              map_location=lambda storage, loc: storage)
    # The checkpoint must come from the same architecture.
    assert opt.arch == pretrain['arch']

    model.load_state_dict(pretrain['state_dict'])

    # Swap the final linear layer for one sized for the fine-tuning classes.
    # Under DataParallel the real network lives in ``model.module``.
    head_owner = model.module if use_cuda else model
    head_name = 'classifier' if opt.model == 'densenet' else 'fc'
    new_head = nn.Linear(getattr(head_owner, head_name).in_features,
                         opt.n_finetune_classes)
    if use_cuda:
        new_head = new_head.cuda()
    setattr(head_owner, head_name, new_head)

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
Esempio n. 8
0
def generate_model(opt):
    """Build the network selected by ``opt`` and pick the parameters to train.

    The architecture family comes from ``opt.model`` and the variant from
    ``opt.model_depth``. Every constructor receives ``num_classes``,
    ``sample_size`` and ``sample_duration``; all families except densenet
    also receive ``shortcut_type``, wideresnet additionally ``k`` and
    resnext additionally ``cardinality``.

    Unless ``opt.no_cuda`` is set, the model is moved to CUDA (to device
    ``opt.cuda_id`` when that is not ``None``) and wrapped in
    ``nn.DataParallel``. If ``opt.pretrain_path`` is truthy, the checkpoint
    is loaded and the final linear layer (``classifier`` for densenet,
    ``fc`` otherwise) is replaced by a fresh one with
    ``opt.n_finetune_classes`` outputs.

    Args:
        opt: Options object exposing the attributes named above, plus
            ``n_classes``, ``resnet_shortcut``, ``wide_resnet_k``,
            ``resnext_cardinality``, ``arch`` and ``ft_begin_index`` as
            needed by the selected branch.

    Returns:
        A ``(model, parameters)`` tuple. With a pretrained checkpoint,
        ``parameters`` is the result of
        ``get_fine_tuning_parameters(model, opt.ft_begin_index)`` from the
        selected family's module; otherwise it is ``model.parameters()``.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is not
            supported, or (on every path except CUDA + ``resnet_skeleton``)
            the checkpoint's ``'arch'`` differs from ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'resnet_skeleton', 'preresnet', 'wideresnet', 'resnext',
        'densenet'
    ]

    # Per family: validate the depth, import the matching fine-tuning helper,
    # and remember which module / constructor prefix / extra kwargs to use.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        family, prefix = resnet, 'resnet'
        family_kwargs = {'shortcut_type': opt.resnet_shortcut}
    elif opt.model == 'resnet_skeleton':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet_skeleton import get_fine_tuning_parameters
        family, prefix = resnet_skeleton, 'resnet_skeleton'
        family_kwargs = {'shortcut_type': opt.resnet_shortcut}
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]
        from models.wide_resnet import get_fine_tuning_parameters
        family, prefix = wide_resnet, 'resnet'
        family_kwargs = {
            'shortcut_type': opt.resnet_shortcut,
            'k': opt.wide_resnet_k,
        }
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        family, prefix = resnext, 'resnet'
        family_kwargs = {
            'shortcut_type': opt.resnet_shortcut,
            'cardinality': opt.resnext_cardinality,
        }
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]
        from models.pre_act_resnet import get_fine_tuning_parameters
        family, prefix = pre_act_resnet, 'resnet'
        family_kwargs = {'shortcut_type': opt.resnet_shortcut}
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]
        from models.densenet import get_fine_tuning_parameters
        family, prefix = densenet, 'densenet'
        family_kwargs = {}

    # Constructors are named '<prefix><depth>' (e.g. resnet.resnet50,
    # resnet_skeleton.resnet_skeleton18, densenet.densenet121), so one
    # lookup replaces the per-depth branches.
    build_kwargs = {
        'num_classes': opt.n_classes,
        'sample_size': opt.sample_size,
        'sample_duration': opt.sample_duration,
    }
    build_kwargs.update(family_kwargs)
    constructor = getattr(family, '{}{}'.format(prefix, int(opt.model_depth)))
    model = constructor(**build_kwargs)

    use_cuda = not opt.no_cuda
    if use_cuda:
        # ``cuda_id is None`` means "default device, all GPUs"; an explicit
        # id (including 0) pins both the model and DataParallel to it.
        if opt.cuda_id is None:
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=None)
        else:
            model = model.cuda(opt.cuda_id)
            model = nn.DataParallel(model, device_ids=[opt.cuda_id])

    if not opt.pretrain_path:
        return model, model.parameters()

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if use_cuda:
        print('    loading pretrained model {}'.format(opt.pretrain_path))
        pretrain = torch.load(opt.pretrain_path)
    else:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        # Remap every storage to CPU so that a checkpoint saved from CUDA
        # tensors can still be read on a machine without CUDA.
        pretrain = torch.load(opt.pretrain_path,
                              map_location=lambda storage, loc: storage)

    if use_cuda and opt.model == 'resnet_skeleton':
        # Partial load: keep only checkpoint entries whose key exists in the
        # current model and does not contain 'fc'; everything else keeps its
        # freshly initialised value. No 'arch' check is done on this path.
        # (The original author's note on the 'fc' exclusion was "for
        # concatenate" -- presumably the fc input differs; TODO confirm.)
        # NOTE(review): the CPU path below does not apply this partial
        # loading for resnet_skeleton -- confirm that is intended.
        model_dict = model.state_dict()
        model_dict.update({
            key: value
            for key, value in pretrain['state_dict'].items()
            if key in model_dict and 'fc' not in key
        })
        model.load_state_dict(model_dict)
    else:
        # The checkpoint must come from the same architecture.
        assert opt.arch == pretrain['arch']
        model.load_state_dict(pretrain['state_dict'])

    # Swap the final linear layer for one sized for the fine-tuning classes.
    # Under DataParallel the real network lives in ``model.module``.
    head_owner = model.module if use_cuda else model
    head_name = 'classifier' if opt.model == 'densenet' else 'fc'
    new_head = nn.Linear(getattr(head_owner, head_name).in_features,
                         opt.n_finetune_classes)
    if use_cuda:
        if opt.cuda_id is None:
            new_head = new_head.cuda()
        else:
            new_head = new_head.cuda(opt.cuda_id)
    setattr(head_owner, head_name, new_head)

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
Esempio n. 9
0
def generate_model(opt):
    # import pdb;pdb.set_trace()
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet', 'se_resnet'
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnet.resnet10(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration,
                                    channels=opt.channels)
        elif opt.model_depth == 18:
            model = resnet.resnet18(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration,
                                    channels=opt.channels)
        elif opt.model_depth == 34:
            model = resnet.resnet34(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration,
                                    channels=opt.channels)
        elif opt.model_depth == 50:
            model = resnet.resnet50(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration,
                                    channels=opt.channels)
        elif opt.model_depth == 101:
            model = resnet.resnet101(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration,
                                     channels=opt.channels)
        elif opt.model_depth == 152:
            model = resnet.resnet152(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration,
                                     channels=opt.channels)
        elif opt.model_depth == 200:
            model = resnet.resnet200(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration,
                                     channels=opt.channels)

    elif opt.model == 'se_resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.se_resnet import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = se_resnet.se_resnet10(num_classes=opt.n_classes,
                                          shortcut_type=opt.resnet_shortcut,
                                          sample_size=opt.sample_size,
                                          sample_duration=opt.sample_duration,
                                          channels=opt.channels)
        elif opt.model_depth == 18:
            model = se_resnet.se_resnet18(num_classes=opt.n_classes,
                                          shortcut_type=opt.resnet_shortcut,
                                          sample_size=opt.sample_size,
                                          sample_duration=opt.sample_duration,
                                          channels=opt.channels)
        elif opt.model_depth == 34:
            model = se_resnet.se_resnet34(num_classes=opt.n_classes,
                                          shortcut_type=opt.resnet_shortcut,
                                          sample_size=opt.sample_size,
                                          sample_duration=opt.sample_duration,
                                          channels=opt.channels)
        elif opt.model_depth == 50:
            model = se_resnet.se_resnet50(num_classes=opt.n_classes,
                                          shortcut_type=opt.resnet_shortcut,
                                          sample_size=opt.sample_size,
                                          sample_duration=opt.sample_duration,
                                          channels=opt.channels)
        elif opt.model_depth == 101:
            model = se_resnet.se_resnet101(num_classes=opt.n_classes,
                                           shortcut_type=opt.resnet_shortcut,
                                           sample_size=opt.sample_size,
                                           sample_duration=opt.sample_duration,
                                           channels=opt.channels)
        elif opt.model_depth == 152:
            model = se_resnet.se_resnet152(num_classes=opt.n_classes,
                                           shortcut_type=opt.resnet_shortcut,
                                           sample_size=opt.sample_size,
                                           sample_duration=opt.sample_duration,
                                           channels=opt.channels)
        elif opt.model_depth == 200:
            model = se_resnet.se_resnet200(num_classes=opt.n_classes,
                                           shortcut_type=opt.resnet_shortcut,
                                           sample_size=opt.sample_size,
                                           sample_duration=opt.sample_duration,
                                           channels=opt.channels)

    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        if opt.model_depth == 50:
            model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                         shortcut_type=opt.resnet_shortcut,
                                         k=opt.wide_resnet_k,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration,
                                         channels=opt.channels)

    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        if opt.model_depth == 50:
            model = resnext.resnet50(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     cardinality=opt.resnext_cardinality,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration,
                                     channels=opt.channels)
        elif opt.model_depth == 101:
            model = resnext.resnet101(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration,
                                      channels=opt.channels)
        elif opt.model_depth == 152:
            model = resnext.resnet152(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration,
                                      channels=opt.channels)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        if opt.model_depth == 18:
            model = pre_act_resnet.resnet18(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                channels=opt.channels)
        elif opt.model_depth == 34:
            model = pre_act_resnet.resnet34(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                channels=opt.channels)
        elif opt.model_depth == 50:
            model = pre_act_resnet.resnet50(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                channels=opt.channels)
        elif opt.model_depth == 101:
            model = pre_act_resnet.resnet101(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                channels=opt.channels)
        elif opt.model_depth == 152:
            model = pre_act_resnet.resnet152(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                channels=opt.channels)
        elif opt.model_depth == 200:
            model = pre_act_resnet.resnet200(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration,
                channels=opt.channels)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        if opt.model_depth == 121:
            model = densenet.densenet121(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration,
                                         channels=opt.channels)
        elif opt.model_depth == 169:
            model = densenet.densenet169(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration,
                                         channels=opt.channels)
        elif opt.model_depth == 201:
            model = densenet.densenet201(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration,
                                         channels=opt.channels)
        elif opt.model_depth == 264:
            model = densenet.densenet264(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration,
                                         channels=opt.channels)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            pretrain_dict = pretrain['state_dict']
            model_dict = model.state_dict()
            pretrain_dict = {
                k: v
                for k, v in pretrain_dict.items() if k in model_dict
            }
            # 更新现有的model_dict
            w = pretrain_dict['module.conv1.weight']
            pretrain_dict['module.conv1.weight'] = torch.nn.Parameter(
                w[:, :1, :, :])
            w_fc = pretrain_dict['module.fc.weight']
            pretrain_dict['module.fc.weight'] = torch.nn.Parameter(
                w_fc[:opt.n_finetune_classes, :])
            w_bias = pretrain_dict['module.fc.bias']
            pretrain_dict['module.fc.bias'] = torch.nn.Parameter(
                w_bias[:opt.n_finetune_classes])
            model_dict.update(pretrain_dict)
            model.load_state_dict(model_dict)

            #model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
Esempio n. 10
0
def generate_model(opt):
    """Build the network selected by ``opt`` and the parameters to optimise.

    ``opt.model`` chooses the architecture family and ``opt.model_depth`` the
    variant inside it; ``'standard'`` builds ``C3D_model.C3D`` and does not
    look at the depth.  Unless ``opt.no_cuda`` is set, the model is moved to
    the GPU and wrapped in ``nn.DataParallel``.

    If ``opt.pretrain_path`` is set, the checkpoint stored there is loaded
    (its ``'arch'`` entry must equal ``opt.arch``) and the final linear layer
    is replaced by a new one with ``opt.n_finetune_classes`` outputs.

    Returns:
        A ``(model, parameters)`` tuple.  ``parameters`` comes from the
        family's ``get_fine_tuning_parameters(model, opt.ft_begin_index)``
        when a checkpoint was loaded, otherwise it is ``model.parameters()``.

    Raises:
        AssertionError: if the model name or depth is not supported, or the
            checkpoint was saved for a different architecture.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet', 'standard'
    ]

    def build(family, prefix, **extra):
        """Call ``family.<prefix><depth>`` with the shared input options."""
        factory = getattr(family, '%s%d' % (prefix, opt.model_depth))
        return factory(num_classes=opt.n_classes,
                       sample_size=opt.sample_size,
                       sample_duration=opt.sample_duration,
                       **extra)

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        model = build(resnet, 'resnet', shortcut_type=opt.resnet_shortcut)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = build(wide_resnet, 'resnet',
                      shortcut_type=opt.resnet_shortcut,
                      k=opt.wide_resnet_k)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        model = build(resnext, 'resnet',
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        model = build(pre_act_resnet, 'resnet',
                      shortcut_type=opt.resnet_shortcut)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        model = build(densenet, 'densenet')
    else:  # 'standard' -- the only name left after the assertion above
        from models.C3D_model import get_fine_tuning_parameters

        model = C3D_model.C3D(opt.n_classes)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = nn.DataParallel(model.cuda(), device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    pretrain = torch.load(opt.pretrain_path)
    assert opt.arch == pretrain['arch']

    model.load_state_dict(pretrain['state_dict'])

    # Name of the final linear layer to swap out.  The original code used
    # 'classifier' for densenet and 'fc8' for the C3D path; every other
    # family is treated as exposing 'fc' (presumably like the plain resnet
    # -- confirm against the model definitions).
    if opt.model == 'densenet':
        head_name = 'classifier'
    elif opt.model == 'standard':
        head_name = 'fc8'
    else:
        head_name = 'fc'

    # DataParallel keeps the real network under ``.module``; on the CPU path
    # the model is not wrapped, so the layer lives on the model itself.
    net = model.module if use_cuda else model
    new_head = nn.Linear(getattr(net, head_name).in_features,
                         opt.n_finetune_classes)
    if use_cuda:
        new_head = new_head.cuda()
    setattr(net, head_name, new_head)

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
Esempio n. 11
0
def generate_model(opt, modality):
    """Create the network named by ``opt.model`` and ready it for fine-tuning.

    Steps, in order:
      1. instantiate the requested architecture / depth;
      2. unless ``opt.no_cuda`` is set, move it to the GPU and wrap it in
         ``nn.DataParallel``;
      3. if ``opt.pretrain_path`` is set, copy every checkpoint tensor except
         ``module.fc.weight`` / ``module.fc.bias`` into the model;
      4. adapt the model to ``modality`` ('RGB', 'Depth' or 'RGB-D');
      5. shrink the first Conv3d when its temporal kernel is longer than
         ``opt.n_frames_per_clip``;
      6. install a new ``fc`` layer.

    The GPU and CPU paths differ in a few details (checkpoint key names, the
    RGB conv rewrite, some log wording); those are kept exactly as they were.

    Returns:
        ``(model, parameters)`` where ``parameters`` is the result of the
        family's ``get_fine_tuning_parameters(model, opt.ft_begin_index)``.
    """
    assert opt.model in ['resnet', 'resnetl', 'resnext', 'c3d']

    if opt.model == 'resnet':
        # No depth assertion for this family: a depth other than 10/18/50
        # leaves ``model`` unassigned.
        from models.resnet import get_fine_tuning_parameters

        if opt.model_depth in (10, 18, 50):
            build = getattr(resnet, 'resnet%d' % opt.model_depth)
            model = build(num_classes=opt.n_classes,
                          shortcut_type=opt.resnet_shortcut,
                          sample_size=opt.sample_size,
                          sample_duration=opt.n_frames_per_clip)

    elif opt.model == 'resnetl':
        assert opt.model_depth in [10]

        from models.resnetl import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnetl.resnetl10(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.n_frames_per_clip)

    elif opt.model == 'resnext':
        assert opt.model_depth in [101]

        from models.resnext import get_fine_tuning_parameters

        if opt.model_depth == 101:
            model = resnext.resnet101(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                cardinality=opt.resnext_cardinality,
                sample_size=opt.sample_size,
                sample_duration=opt.n_frames_per_clip)

    elif opt.model == 'c3d':
        assert opt.model_depth in [10]

        from models.c3d import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = c3d.c3d_v1(sample_size=opt.sample_size,
                               sample_duration=opt.n_frames_per_clip,
                               num_classes=opt.n_classes)

    on_gpu = not opt.no_cuda
    if on_gpu:
        model = nn.DataParallel(model.cuda(), device_ids=None)

    if opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        checkpoint = torch.load(opt.pretrain_path, map_location='cpu')

        # The classifier tensors are never taken from the checkpoint.
        skipped = ['module.fc.weight', 'module.fc.bias']
        backbone_weights = {}
        for name, tensor in checkpoint['state_dict'].items():
            if name in skipped:
                continue
            if not on_gpu:
                # The unwrapped model has no 'module.' in its key names.
                name = name.replace('module.', '')
            backbone_weights[name] = tensor

        merged = deepcopy(model.state_dict())
        merged.update(backbone_weights)
        model.load_state_dict(merged)

    if modality == 'RGB' and opt.model != 'c3d':
        print("[INFO]: RGB model is used for init model")
        if not on_gpu:
            # Only the CPU path rewrites the first conv layer for RGB.
            model = _modify_first_conv_layer(model, 3, 3)
    elif modality == 'Depth':
        print("[INFO]: Converting the pretrained model to Depth init model")
        model = _construct_depth_model(model)
        if on_gpu:
            print("[INFO]: Done. Flow model ready.")
        else:
            print("[INFO]: Deoth model ready.")
    elif modality == 'RGB-D':
        if on_gpu:
            print(
                "[INFO]: Converting the pretrained model to RGB+D init model")
        else:
            print(
                "[INFO]: Converting the pretrained model to RGB-D init model")
        model = _construct_rgbdepth_model(model)
        print("[INFO]: Done. RGB-D model ready.")

    # Check first kernel size
    conv_layers = [
        layer for layer in model.modules() if isinstance(layer, nn.Conv3d)
    ]
    if conv_layers[0].kernel_size[0] > opt.n_frames_per_clip:
        print("[INFO]: RGB model is used for init model")
        model = _modify_first_conv_layer(model,
                                         int(opt.n_frames_per_clip / 2), 1)

    # Under DataParallel the real network sits behind ``.module``.
    net = model.module if on_gpu else model
    if opt.model == 'c3d':  # CHECK HERE
        # Rebuilt with the same in/out sizes as the first entry of the
        # existing ``fc``.
        previous = net.fc[0]
        net.fc = nn.Linear(previous.in_features, previous.out_features)
    else:
        net.fc = nn.Linear(net.fc.in_features, opt.n_finetune_classes)
    if on_gpu:
        net.fc = net.fc.cuda()

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
Esempio n. 12
0
def generate_model(opt):
    """Build a resnet of depth ``opt.model_depth`` and its trainable parameters.

    Only ``opt.model == 'resnet'`` is accepted.  Unless ``opt.no_cuda`` is
    set, the model is moved to the GPU and wrapped in ``nn.DataParallel``.
    Weights from ``opt.pretrain_path`` and then ``opt.audio_pretrain_path``
    are loaded non-strictly when those options are set; the first checkpoint
    must carry an ``'arch'`` entry equal to ``opt.arch``.

    Returns:
        ``(model, parameters)``; ``parameters`` is the first of the two
        values returned by ``get_fine_tuning_parameters``.

    Raises:
        ValueError: if ``opt.model`` is anything other than ``'resnet'``.
        AssertionError: on an unsupported depth or architecture mismatch.
    """
    if opt.model != 'resnet':
        raise ValueError('Model type not supported')

    assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

    from models.resnet import get_fine_tuning_parameters

    # Factories are named resnet10, resnet18, ... in the resnet module.
    build = getattr(resnet, 'resnet%d' % opt.model_depth)
    model = build(shortcut_type=opt.resnet_shortcut,
                  sample_size=opt.sample_size,
                  sample_duration=opt.sample_duration,
                  audiovisual=opt.audiovisual)

    if not opt.no_cuda:
        model = nn.DataParallel(model.cuda(), device_ids=None)

    if opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        checkpoint = torch.load(opt.pretrain_path)
        assert opt.arch == checkpoint['arch']
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    if opt.audio_pretrain_path:
        print('loading pretrained audio model {}'.format(
            opt.audio_pretrain_path))
        # This file is passed to load_state_dict directly, with no
        # 'state_dict' lookup.
        audio_weights = torch.load(opt.audio_pretrain_path)
        model.load_state_dict(audio_weights, strict=False)

    parameters, _ = get_fine_tuning_parameters(model,
                                               opt.learning_rate_sal,
                                               opt.weight_decay)

    return model, parameters
Esempio n. 13
0
def generate_model(opt):
    """Build the network selected by ``opt`` and the parameters to optimise.

    ``opt.model`` chooses the architecture family and ``opt.model_depth`` the
    variant inside it.  Unless ``opt.no_cuda`` is set, the model is moved to
    the GPU and wrapped in ``nn.DataParallel``.

    If ``opt.pretrain_path`` is set, weights are loaded from that checkpoint
    and the final linear layer (``classifier`` for densenet, ``fc``
    otherwise) is replaced by a new one with ``opt.n_finetune_classes``
    outputs.

    On the GPU path only checkpoint tensors whose name and shape match the
    model are used, and the model keeps its own values for the rest.  The
    CPU path loads the checkpoint strictly.

    Returns:
        A ``(model, parameters)`` tuple.  ``parameters`` comes from the
        family's ``get_fine_tuning_parameters(model, opt.ft_begin_index)``
        when a checkpoint was loaded, otherwise it is ``model.parameters()``.

    Raises:
        AssertionError: if the model name or depth is not supported.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    def build(family, prefix, **extra):
        """Call ``family.<prefix><depth>`` with the shared input options."""
        factory = getattr(family, '%s%d' % (prefix, opt.model_depth))
        return factory(num_classes=opt.n_classes,
                       sample_size=opt.sample_size,
                       sample_duration=opt.sample_duration,
                       **extra)

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        model = build(resnet, 'resnet', shortcut_type=opt.resnet_shortcut)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = build(wide_resnet, 'resnet',
                      shortcut_type=opt.resnet_shortcut,
                      k=opt.wide_resnet_k)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        model = build(resnext, 'resnet',
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        model = build(pre_act_resnet, 'resnet',
                      shortcut_type=opt.resnet_shortcut)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        model = build(densenet, 'densenet')

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = nn.DataParallel(model.cuda(), device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    pretrain = torch.load(opt.pretrain_path)
    # The opt.arch == pretrain['arch'] check was disabled in the original
    # code and is left disabled here.

    if use_cuda:
        pretrained_dict = pretrain['state_dict']
        model_dict = model.state_dict()

        # 1. keep only checkpoint tensors the model can actually take
        usable = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        print('using {} of {} pretrained tensors'.format(
            len(usable), len(pretrained_dict)))
        # 2. overwrite the model's entries with the pretrained ones.  The
        #    merge must go in this direction: updating the checkpoint dict
        #    with the model's dict would throw the pretrained values away.
        model_dict.update(usable)
        # 3. load the merged state dict
        model.load_state_dict(model_dict)
    else:
        model.load_state_dict(pretrain['state_dict'])

    # DataParallel keeps the real network under ``.module``.
    net = model.module if use_cuda else model
    head_name = 'classifier' if opt.model == 'densenet' else 'fc'
    new_head = nn.Linear(getattr(net, head_name).in_features,
                         opt.n_finetune_classes)
    if use_cuda:
        new_head = new_head.cuda()
    setattr(net, head_name, new_head)

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
Esempio n. 14
0
def generate_model(opt):
    """Build the network selected by ``opt`` and optionally load a checkpoint.

    ``opt.model`` and ``opt.model_depth`` pick the constructor. For
    ``resnet`` the number of input channels is derived from ``opt.is_rgb``
    and ``opt.is_depth`` (both -> 4, RGB only -> 3, depth only -> 1) and
    passed to the constructor as ``channel``.

    Unless ``opt.no_cuda`` is set, the model is moved to the GPU and wrapped
    in ``nn.DataParallel``. When ``opt.pretrain_path`` is set, the checkpoint
    is loaded, the last layer (``classifier`` for densenet, ``fc`` otherwise)
    is replaced by a new ``nn.Linear`` with ``opt.n_finetune_classes``
    outputs, and the fine-tuning parameters are selected with
    ``get_fine_tuning_parameters(model, opt.ft_begin_index)``.

    Returns:
        ``(model, parameters)`` when ``opt.pretrain_path`` is set, otherwise
        only ``model``. The two return shapes are kept as-is so existing
        callers keep working.

    Raises:
        ValueError: if a resnet is requested and neither ``opt.is_rgb`` nor
            ``opt.is_depth`` is set, so the input channel count is unknown.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    depth = opt.model_depth

    if opt.model == 'resnet':
        assert depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        # The combined RGB+depth case has to be tested first; checking
        # is_rgb alone first would shadow it and never yield 4 channels.
        # The channel count is computed regardless of pretrain_path because
        # the constructor needs it whether or not a checkpoint is loaded.
        if opt.is_rgb and opt.is_depth:
            ch = 4
        elif opt.is_rgb:
            ch = 3
        elif opt.is_depth:
            ch = 1
        else:
            raise ValueError(
                'cannot determine input channels: '
                'neither opt.is_rgb nor opt.is_depth is set')

        model = getattr(resnet, 'resnet{}'.format(depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration,
            channel=ch)
    elif opt.model == 'wideresnet':
        assert depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = getattr(wide_resnet, 'resnet{}'.format(depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            k=opt.wide_resnet_k,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        model = getattr(resnext, 'resnet{}'.format(depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        model = getattr(pre_act_resnet, 'resnet{}'.format(depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        model = getattr(densenet, 'densenet{}'.format(depth))(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model_dict = model.state_dict()
            pretrained_dict = pretrain['state_dict']
            # Whenever depth input is involved the first convolution is not
            # taken from the checkpoint; it keeps its freshly initialised
            # weights from the model built above.
            if opt.is_depth:
                pretrained_dict = {
                    k: v
                    for k, v in pretrained_dict.items()
                    if k != 'module.conv1.weight'
                }
            # Overlay the checkpoint on the current weights, then load the
            # merged dictionary back into the model.
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            # Swap the output layer for one sized for the fine-tuning task.
            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            # NOTE(review): unlike the GPU path, this loads the checkpoint
            # unfiltered, including the first convolution.
            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.classifier = nn.Linear(
                    model.classifier.in_features, opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model
Esempio n. 15
0
def generate_model(opt):
    """Build the clip model selected by ``opt`` and place it on its device(s).

    ``opt.clip_model`` picks the architecture (``resnet``, ``r2plus1d``,
    ``avts`` or ``i3d``); ``opt.clip_model_depth`` picks the variant where
    the architecture has several.

    Device placement:
      * ``opt.distributed``: the model is wrapped in
        ``DistributedDataParallel``. With ``opt.gpu`` set, that single device
        is used and ``opt.batch_size`` / ``opt.n_threads`` are divided by
        ``opt.ngpus_per_node`` (this mutates ``opt``).
      * ``opt.gpu`` set, not distributed: the bare model is moved to that GPU.
      * otherwise: the model is wrapped in ``DataParallel``.

    When ``opt.pretrain_path`` is set, the checkpoint's ``state_dict`` is
    loaded, ``fc`` is replaced by a new ``nn.Linear`` with
    ``opt.n_finetune_classes`` outputs, and fine-tuning parameters are
    selected (per-layer learning rates when ``opt.layer_lr`` is not None).

    Returns:
        ``(model, parameters)`` when ``opt.pretrain_path`` is set, otherwise
        only ``model``.

    Raises:
        ValueError: if ``i3d`` is requested with a depth other than 34 or 50.
    """
    assert opt.clip_model in ['resnet', 'r2plus1d', 'avts', 'i3d']

    if opt.clip_model == 'resnet':
        assert opt.clip_model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters
        from models.resnet import get_fine_tuning_parameters_layer_lr

        build_kwargs = dict(num_classes=opt.n_classes,
                            shortcut_type=opt.resnet_shortcut,
                            sample_size=opt.sample_size,
                            sample_duration=opt.sample_duration)
        # Only the depth-50 variant is given the no_last_fc switch.
        if opt.clip_model_depth == 50:
            build_kwargs['no_last_fc'] = opt.no_last_fc
        model = getattr(
            resnet, 'resnet{}'.format(opt.clip_model_depth))(**build_kwargs)

    elif opt.clip_model.lower() in ['r2+1d', 'r2.5d', 'r2plus1d']:
        from models.r2plus1d import get_fine_tuning_parameters, get_fine_tuning_parameters_layer_lr
        print("Making R2+1D model, depth", opt.clip_model_depth)
        # The 34-layer variant is built regardless of the printed depth.
        model = r2plus1d.r2plus1d_34(num_classes=opt.n_classes)

    elif opt.clip_model.lower() in ['avts']:
        from models.avts import get_fine_tuning_parameters, get_fine_tuning_parameters_layer_lr
        print("Making AVTS model")
        model = avts.mc3_avts()
        # Append a flatten step and a classification layer to the backbone.
        model.add_module("flatten", torch.nn.Flatten(1))
        model.add_module("fc", nn.Linear(256, opt.n_classes))

    elif opt.clip_model == 'i3d':
        # NOTE(review): this branch imports no fine-tuning helper, so
        # combining i3d with opt.pretrain_path fails at the parameter
        # selection step below.
        from models import i3res
        if opt.clip_model_depth == 50:
            print("Making I3D model, depth", opt.clip_model_depth)
            model = i3res.i3res_50(sample_size=opt.sample_size,
                                   sample_duration=opt.sample_duration,
                                   num_classes=opt.n_classes)
        elif opt.clip_model_depth == 34:
            print("Making I3D model, depth", opt.clip_model_depth)
            model = i3res.i3res_34(sample_size=opt.sample_size,
                                   sample_duration=opt.sample_duration,
                                   num_classes=opt.n_classes)
        else:
            # Fail here with a clear message instead of hitting an unbound
            # ``model`` further down.
            raise ValueError(
                'unsupported i3d depth: {}'.format(opt.clip_model_depth))

    if opt.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if opt.gpu is not None:
            torch.cuda.set_device(opt.gpu)
            model.cuda(opt.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            opt.batch_size = int(opt.batch_size / opt.ngpus_per_node)
            opt.n_threads = int(opt.n_threads / opt.ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[opt.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif opt.gpu is not None:
        torch.cuda.set_device(opt.gpu)
        model = model.cuda(opt.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()

    if opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        pretrain = torch.load(opt.pretrain_path)
        model.load_state_dict(pretrain['state_dict'])

        # Only the (Distributed)DataParallel wrappers expose ``.module``; in
        # the single-GPU path above the model is left unwrapped, so reach the
        # underlying network accordingly.
        wrappers = (torch.nn.DataParallel,
                    torch.nn.parallel.DistributedDataParallel)
        net = model.module if isinstance(model, wrappers) else model

        net.fc = nn.Linear(net.fc.in_features, opt.n_finetune_classes)
        net.fc = net.fc.cuda()

        if opt.layer_lr is not None:
            parameters = get_fine_tuning_parameters_layer_lr(
                model, opt.ft_begin_index, opt.layer_lr)
        else:
            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)

        return model, parameters

    return model
Esempio n. 16
0
def generate_model(opt):
    """Build the network selected by ``opt`` and optionally load a checkpoint.

    ``opt.model`` and ``opt.model_depth`` pick the constructor. Unless
    ``opt.no_cuda`` is set, the model is moved to the GPU and wrapped in
    ``nn.DataParallel``.

    When ``opt.pretrain_path`` is set, the checkpoint is loaded (its
    ``'arch'`` entry must equal ``opt.arch``), the last layer
    (``classifier`` for densenet, ``fc`` otherwise) is replaced by a new
    ``nn.Linear`` with ``opt.n_finetune_classes`` outputs, and the
    fine-tuning parameters are selected with
    ``get_fine_tuning_parameters(model, opt.ft_begin_index)``.

    Returns:
        ``(model, parameters)``; without a checkpoint ``parameters`` is
        ``model.parameters()``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    depth = opt.model_depth

    if opt.model == 'resnet':
        assert depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        model = getattr(resnet, 'resnet{}'.format(depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = getattr(wide_resnet, 'resnet{}'.format(depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            k=opt.wide_resnet_k,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        model = getattr(resnext, 'resnet{}'.format(depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        model = getattr(pre_act_resnet, 'resnet{}'.format(depth))(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        model = getattr(densenet, 'densenet{}'.format(depth))(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            # Swap the output layer for one sized for the fine-tuning task.
            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # Remap storages to the CPU so a checkpoint saved from GPU
            # tensors can still be read on this CPU-only path.
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            assert opt.arch == pretrain['arch']

            # nn.DataParallel prefixes every key with 'module.' when a model
            # is saved through it. Strip that prefix only where it is really
            # present, so checkpoints saved from a bare model keep their keys.
            prefix = 'module.'
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in pretrain['state_dict'].items()
            }

            model.load_state_dict(state_dict)

            if opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
Esempio n. 17
0
def generate_model(opt):
    assert opt.model in [
        'resnet',
        'preresnet',
        'wideresnet',
        'resnext',
        'densenet',
        'c3d',
        'c2d',
        'c2d_exp',
        'c2d_coord',
        'c3d_color',
        'c2d_pt',
        'c2d_pt2',
        'c2d_pt5',
        'c2d_pt7',
        'c2d_pt_exp',
        'c2d_pt2_exp',
        'c2d_pt5_exp',
        'c2d_pt_exp_avg',
        'c2d_pt_exp_sep',
        'c3d_pt_exp',
        'c2d_pt_exp_init',
        'c2d_pt_expc',
        'resnet18_exp',
        'resnet34_exp',
        'resnet50_exp',
        'resnet101_exp',
        'resnet152_exp',
        'resnext50_32x4d_exp',
        'resnext101_32x8d_exp',
        'wide_resnet50_2_exp',
        'wide_resnet101_2_exp',
        'resnet18_pt_exp',
        'resnet34_pt_exp',
        'resnet50_pt_exp',
        'resnet101_pt_exp',
        'resnet152_pt_exp',
        'resnext50_32x4d_pt_exp',
        'resnext101_32x8d_pt_exp',
        'wide_resnet50_2_pt_exp',
        'wide_resnet101_2_pt_exp',
        # decoder
        'stsrresnetexp',
        'spc',
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnet.resnet10(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 18:
            model = resnet.resnet18(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 34:
            model = resnet.resnet34(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 50:
            model = resnet.resnet50(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = resnet.resnet101(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = resnet.resnet152(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
        elif opt.model_depth == 200:
            model = resnet.resnet200(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        if opt.model_depth == 50:
            model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                         shortcut_type=opt.resnet_shortcut,
                                         k=opt.wide_resnet_k,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        if opt.model_depth == 50:
            model = resnext.resnet50(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     cardinality=opt.resnext_cardinality,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = resnext.resnet101(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = resnext.resnet152(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        if opt.model_depth == 18:
            model = pre_act_resnet.resnet18(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 34:
            model = pre_act_resnet.resnet34(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 50:
            model = pre_act_resnet.resnet50(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 101:
            model = pre_act_resnet.resnet101(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 152:
            model = pre_act_resnet.resnet152(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
        elif opt.model_depth == 200:
            model = pre_act_resnet.resnet200(
                num_classes=opt.n_classes,
                shortcut_type=opt.resnet_shortcut,
                sample_size=opt.sample_size,
                sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        if opt.model_depth == 121:
            model = densenet.densenet121(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
        elif opt.model_depth == 169:
            model = densenet.densenet169(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
        elif opt.model_depth == 201:
            model = densenet.densenet201(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
        elif opt.model_depth == 264:
            model = densenet.densenet264(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)

    elif opt.model == 'c3d':
        model = c3d.C3D(num_classes=opt.n_classes,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'c3d_color':
        model = c3d_color.C3D(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'spc':
        model = spc.SPC(num_classes=opt.n_classes,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'c2d':
        model = c2d.C2D(num_classes=opt.n_classes,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt':
        model = c2d_pt.C2DPt(num_classes=opt.n_classes,
                             sample_size=opt.sample_size,
                             sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt2':
        model = c2d_pt2.C2DPt(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt5':
        model = c2d_pt5.C2DPt(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt7':
        model = c2d_pt7.C2DPt(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_exp':
        model = c2d_exp.C2DExp(num_classes=opt.n_classes,
                               sample_size=opt.sample_size,
                               sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt_exp':
        model = c2d_pt_exp.C2DPtExp(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt_expc':
        model = c2d_pt_expc.C2DPtExpC(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt_exp_init':
        model = c2d_pt_exp_init.C2DPtExp(num_classes=opt.n_classes,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)
    elif opt.model == 'c3d_pt_exp':
        model = c3d_pt_exp.C3DPtExp(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt_exp_avg':
        model = c2d_pt_exp_avg.C2DPtExpAvg(num_classes=opt.n_classes,
                                           sample_size=opt.sample_size,
                                           sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt_exp_sep':
        model = c2d_pt_exp_sep.C2DPtExpSep(num_classes=opt.n_classes,
                                           sample_size=opt.sample_size,
                                           sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt5_exp':
        model = c2d_pt5_exp.C2DPtExp(num_classes=opt.n_classes,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_pt2_exp':
        model = c2d_pt2_exp.C2DPtExp(num_classes=opt.n_classes,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'c2d_coord':
        model = c2d_coord.C2DCoord(num_classes=opt.n_classes,
                                   sample_size=opt.sample_size,
                                   sample_duration=opt.sample_duration)

    elif opt.model == 'resnet18_exp':
        model = resnet_exp.resnet18(pretrained=False,
                                    progress=True,
                                    num_classes=opt.n_classes,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'resnet34_exp':
        model = resnet_exp.resnet34(pretrained=False,
                                    progress=True,
                                    num_classes=opt.n_classes,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'resnet50_exp':
        model = resnet_exp.resnet50(pretrained=False,
                                    progress=True,
                                    num_classes=opt.n_classes,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'resnet101_exp':
        model = resnet_exp.resnet101(pretrained=False,
                                     progress=True,
                                     num_classes=opt.n_classes,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'resnet152_exp':
        model = resnet_exp.resnet152(pretrained=False,
                                     progress=True,
                                     num_classes=opt.n_classes,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'resnext50_32x4d_exp':
        model = resnet_exp.resnext50_32x4d(pretrained=False,
                                           progress=True,
                                           num_classes=opt.n_classes,
                                           sample_duration=opt.sample_duration)
    elif opt.model == 'resnext101_32x8d_exp':
        model = resnet_exp.resnext101_32x8d(
            pretrained=False,
            progress=True,
            num_classes=opt.n_classes,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wide_resnet50_2_exp':
        model = resnet_exp.wide_resnet50_2(pretrained=False,
                                           progress=True,
                                           num_classes=opt.n_classes,
                                           sample_duration=opt.sample_duration)
    elif opt.model == 'wide_resnet101_2_exp':
        model = resnet_exp.wide_resnet101_2(
            pretrained=False,
            progress=True,
            num_classes=opt.n_classes,
            sample_duration=opt.sample_duration)

    elif opt.model == 'resnet18_pt_exp':
        model = resnet_pt_exp.resnet18(pretrained=False,
                                       progress=True,
                                       num_classes=opt.n_classes,
                                       sample_duration=opt.sample_duration)
    elif opt.model == 'resnet34_pt_exp':
        model = resnet_pt_exp.resnet34(pretrained=False,
                                       progress=True,
                                       num_classes=opt.n_classes,
                                       sample_duration=opt.sample_duration)
    elif opt.model == 'resnet50_pt_exp':
        model = resnet_pt_exp.resnet50(pretrained=False,
                                       progress=True,
                                       num_classes=opt.n_classes,
                                       sample_duration=opt.sample_duration)
    elif opt.model == 'resnet101_pt_exp':
        model = resnet_pt_exp.resnet101(pretrained=False,
                                        progress=True,
                                        num_classes=opt.n_classes,
                                        sample_duration=opt.sample_duration)
    elif opt.model == 'resnet152_pt_exp':
        model = resnet_pt_exp.resnet152(pretrained=False,
                                        progress=True,
                                        num_classes=opt.n_classes,
                                        sample_duration=opt.sample_duration)
    elif opt.model == 'resnext50_32x4d_pt_exp':
        model = resnet_pt_exp.resnext50_32x4d(
            pretrained=False,
            progress=True,
            num_classes=opt.n_classes,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext101_32x8d_pt_exp':
        model = resnet_pt_exp.resnext101_32x8d(
            pretrained=False,
            progress=True,
            num_classes=opt.n_classes,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wide_resnet50_2_pt_exp':
        model = resnet_pt_exp.wide_resnet50_2(
            pretrained=False,
            progress=True,
            num_classes=opt.n_classes,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wide_resnet101_2_pt_exp':
        model = resnet_pt_exp.wide_resnet101_2(
            pretrained=False,
            progress=True,
            num_classes=opt.n_classes,
            sample_duration=opt.sample_duration)

    elif opt.model == 'stsrresnetexp':
        model = decoder.STSRResNetExp(sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
Esempio n. 18
0
def generate_model(opt):
    """Build the network selected by ``opt`` and return it with its parameters.

    ``opt.model`` picks the architecture family and ``opt.model_depth`` the
    variant inside that family.  Unless ``opt.no_cuda`` is set, the network is
    moved to CUDA and wrapped in ``nn.DataParallel``.  When
    ``opt.pretrain_path`` is set, the checkpoint's ``state_dict`` is loaded
    and the final linear layer (``classifier`` for densenet, ``fc`` otherwise)
    is replaced by a new one with ``opt.n_finetune_classes`` outputs.

    Args:
        opt: Options object.  Attributes read here: ``model``,
            ``model_depth``, ``n_classes``, ``sample_size``,
            ``sample_duration``, ``no_cuda``, ``pretrain_path`` and, depending
            on the branch taken, ``resnet_shortcut``, ``wide_resnet_k``,
            ``resnext_cardinality``, ``arch``, ``n_finetune_classes`` and
            ``ft_begin_index``.

    Returns:
        A ``(model, parameters)`` tuple.  ``parameters`` is the result of
        ``get_fine_tuning_parameters(model, opt.ft_begin_index)`` when a
        checkpoint was loaded, otherwise ``model.parameters()``.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is not one of
            the supported values, or if ``opt.arch`` differs from the
            ``arch`` entry stored in the checkpoint.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    # Each branch imports the get_fine_tuning_parameters that belongs to the
    # selected family; it is only used further down when fine-tuning.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        args = {
            "num_classes": opt.n_classes,
            "shortcut_type": opt.resnet_shortcut,
            "sample_size": opt.sample_size,
            "sample_duration": opt.sample_duration
        }

        if opt.model_depth == 10:
            model = resnet.resnet10(**args)
        elif opt.model_depth == 18:
            model = resnet.resnet18(**args)
        elif opt.model_depth == 34:
            model = resnet.resnet34(**args)
        elif opt.model_depth == 50:
            model = resnet.resnet50(**args)
        elif opt.model_depth == 101:
            model = resnet.resnet101(**args)
        elif opt.model_depth == 152:
            model = resnet.resnet152(**args)
        elif opt.model_depth == 200:
            model = resnet.resnet200(**args)

    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        if opt.model_depth == 50:
            model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                         shortcut_type=opt.resnet_shortcut,
                                         k=opt.wide_resnet_k,
                                         sample_size=opt.sample_size,
                                         sample_duration=opt.sample_duration)

    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        args = {
            "num_classes": opt.n_classes,
            "shortcut_type": opt.resnet_shortcut,
            "cardinality": opt.resnext_cardinality,
            "sample_size": opt.sample_size,
            "sample_duration": opt.sample_duration
        }

        if opt.model_depth == 50:
            model = resnext.resnet50(**args)
        elif opt.model_depth == 101:
            model = resnext.resnet101(**args)
        elif opt.model_depth == 152:
            model = resnext.resnet152(**args)

    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        args = {
            "num_classes": opt.n_classes,
            "shortcut_type": opt.resnet_shortcut,
            "sample_size": opt.sample_size,
            "sample_duration": opt.sample_duration
        }

        if opt.model_depth == 18:
            model = pre_act_resnet.resnet18(**args)
        elif opt.model_depth == 34:
            model = pre_act_resnet.resnet34(**args)
        elif opt.model_depth == 50:
            model = pre_act_resnet.resnet50(**args)
        elif opt.model_depth == 101:
            model = pre_act_resnet.resnet101(**args)
        elif opt.model_depth == 152:
            model = pre_act_resnet.resnet152(**args)
        elif opt.model_depth == 200:
            model = pre_act_resnet.resnet200(**args)

    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        args = {
            "num_classes": opt.n_classes,
            "sample_size": opt.sample_size,
            "sample_duration": opt.sample_duration
        }

        if opt.model_depth == 121:
            model = densenet.densenet121(**args)
        elif opt.model_depth == 169:
            model = densenet.densenet169(**args)
        elif opt.model_depth == 201:
            model = densenet.densenet201(**args)
        elif opt.model_depth == 264:
            model = densenet.densenet264(**args)

    if opt.no_cuda:
        device = 'cpu'
    else:
        device = 'cuda'
        model = model.to(device)
        model = nn.DataParallel(model, device_ids=None)

    if opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        pretrain = torch.load(opt.pretrain_path, map_location=device)
        assert opt.arch == pretrain['arch']

        # NOTE(review): the state dict is loaded into the wrapped model on
        # CUDA but into the bare model on CPU; a checkpoint saved from a
        # DataParallel model would presumably carry 'module.'-prefixed keys
        # and not match on the CPU path -- confirm how checkpoints are saved.
        model.load_state_dict(pretrain['state_dict'])

        # The model is wrapped in DataParallel only on the CUDA path, so
        # ``model.module`` does not exist when running on CPU.  Resolve the
        # underlying network once and swap its head on that object.
        if isinstance(model, nn.DataParallel):
            network = model.module
        else:
            network = model

        if opt.model == 'densenet':
            new_head = nn.Linear(network.classifier.in_features,
                                 opt.n_finetune_classes)
            network.classifier = new_head.to(device)
        else:
            new_head = nn.Linear(network.fc.in_features,
                                 opt.n_finetune_classes)
            network.fc = new_head.to(device)

        parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        return model, parameters

    return model, model.parameters()
Esempio n. 19
0
def generate_model(opt):
    """Build the network selected by ``opt`` and return it with its parameters.

    ``opt.model`` picks the architecture family and ``opt.model_depth`` the
    variant inside that family.  Unless ``opt.no_cuda`` is set, the network is
    moved to CUDA (it is not wrapped in ``nn.DataParallel`` in this variant).
    When ``opt.pretrain`` is set, the checkpoint's ``state_dict`` is loaded
    and the final linear layer (``classifier`` for densenet, ``fc`` otherwise)
    is replaced by a new one with ``opt.n_finetune_classes`` outputs.

    Args:
        opt: Options object.  Attributes read here: ``model``,
            ``model_depth``, ``n_classes``, ``sample_size``,
            ``sample_duration``, ``no_cuda``, ``pretrain`` and, depending on
            the branch taken, ``resnet_shortcut``, ``wide_resnet_k``,
            ``resnext_cardinality``, ``arch``, ``n_finetune_classes`` and
            ``ft_begin_index``.

    Returns:
        A ``(model, parameters)`` tuple.  ``parameters`` is the result of
        ``get_fine_tuning_parameters(model, opt.ft_begin_index)`` when a
        checkpoint was loaded, otherwise ``model.parameters()``.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is not one of
            the supported values, or if ``opt.arch`` differs from the
            ``arch`` entry stored in the checkpoint.
    """
    assert opt.model in [
        'resnet3D', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    # Keyword arguments passed to every constructor below; the residual
    # families extend them with their own extras inside each branch.
    common_args = {
        'num_classes': opt.n_classes,
        'sample_size': opt.sample_size,
        'sample_duration': opt.sample_duration,
    }

    # Each branch imports the get_fine_tuning_parameters that belongs to the
    # selected family; it is only used further down when fine-tuning.
    if opt.model == 'resnet3D':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        args = dict(common_args, shortcut_type=opt.resnet_shortcut)

        if opt.model_depth == 10:
            model = resnet.resnet10(**args)
        elif opt.model_depth == 18:
            model = resnet.resnet18(**args)
        elif opt.model_depth == 34:
            model = resnet.resnet34(**args)
        elif opt.model_depth == 50:
            model = resnet.resnet50(**args)
        elif opt.model_depth == 101:
            model = resnet.resnet101(**args)
        elif opt.model_depth == 152:
            model = resnet.resnet152(**args)
        elif opt.model_depth == 200:
            model = resnet.resnet200(**args)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        args = dict(common_args,
                    shortcut_type=opt.resnet_shortcut,
                    k=opt.wide_resnet_k)

        if opt.model_depth == 50:
            model = wide_resnet.resnet50(**args)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        args = dict(common_args,
                    shortcut_type=opt.resnet_shortcut,
                    cardinality=opt.resnext_cardinality)

        if opt.model_depth == 50:
            model = resnext.resnet50(**args)
        elif opt.model_depth == 101:
            model = resnext.resnet101(**args)
        elif opt.model_depth == 152:
            model = resnext.resnet152(**args)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        args = dict(common_args, shortcut_type=opt.resnet_shortcut)

        if opt.model_depth == 18:
            model = pre_act_resnet.resnet18(**args)
        elif opt.model_depth == 34:
            model = pre_act_resnet.resnet34(**args)
        elif opt.model_depth == 50:
            model = pre_act_resnet.resnet50(**args)
        elif opt.model_depth == 101:
            model = pre_act_resnet.resnet101(**args)
        elif opt.model_depth == 152:
            model = pre_act_resnet.resnet152(**args)
        elif opt.model_depth == 200:
            model = pre_act_resnet.resnet200(**args)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        if opt.model_depth == 121:
            model = densenet.densenet121(**common_args)
        elif opt.model_depth == 169:
            model = densenet.densenet169(**common_args)
        elif opt.model_depth == 201:
            model = densenet.densenet201(**common_args)
        elif opt.model_depth == 264:
            model = densenet.densenet264(**common_args)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()
        # The DataParallel wrapper is disabled in this variant:
        # model = nn.DataParallel(model, device_ids=None)

    if opt.pretrain:
        print('loading pretrained model {}'.format(opt.pretrain))
        # On the CPU path, remap stored tensors to CPU so a checkpoint saved
        # from a GPU can still be deserialized; the CUDA path keeps
        # torch.load's default placement.
        pretrain = torch.load(opt.pretrain,
                              map_location=None if use_cuda else 'cpu')
        assert opt.arch == pretrain['arch']

        model.load_state_dict(pretrain['state_dict'])

        # ``model.module`` only exists on a DataParallel wrapper.  Since the
        # wrapper is disabled above, resolve the underlying network
        # explicitly so the head swap works with or without it.
        if isinstance(model, nn.DataParallel):
            network = model.module
        else:
            network = model

        if opt.model == 'densenet':
            new_head = nn.Linear(network.classifier.in_features,
                                 opt.n_finetune_classes)
        else:
            new_head = nn.Linear(network.fc.in_features,
                                 opt.n_finetune_classes)

        # The freshly created layer lives on CPU; move it next to the rest
        # of the network when running on CUDA.
        if use_cuda:
            new_head = new_head.cuda()

        if opt.model == 'densenet':
            network.classifier = new_head
        else:
            network.fc = new_head

        parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        return model, parameters

    return model, model.parameters()
















# import torch.nn as nn
# import math
# import numpy as np
# import torch
# from torch.autograd import Variable
# from torch.nn.parameter import Parameter
# import random
# import torch.utils.model_zoo as model_zoo
#
# def initweights(m):
#     orthogonal_flag = False
#     for layer in m.modules():
#         if isinstance(layer, nn.Conv2d):
#             n = layer.kernel_size[0] * layer.kernel_size[1] * layer.out_channels
#             layer.weight.data.normal_(0, math.sqrt(2. / n))
#
#
#
#             # orthogonal initialize
#             """Reference:
#             [1] Saxe, Andrew M., James L. McClelland, and Surya Ganguli.
#                "Exact solutions to the nonlinear dynamics of learning in deep
#                linear neural networks." arXiv preprint arXiv:1312.6120 (2013)."""
#             if orthogonal_flag:
#                 weight_shape = layer.weight.data.cpu().numpy().shape
#                 u, _, v = np.linalg.svd(layer.weight.data.cpu().numpy(), full_matrices=False)
#                 flat_shape = (weight_shape[0], np.prod(weight_shape[1:]))
#                 q = u if u.shape == flat_shape else v
#                 q = q.reshape(weight_shape)
#                 layer.weight.data.copy_(torch.Tensor(q))
#
#         elif isinstance(layer, nn.BatchNorm2d):
#             layer.weight.data.fill_(1)
#             layer.bias.data.zero_()
#         elif isinstance(layer, nn.Linear):
#             layer.bias.data.zero_()
#
# def conv3x3(in_planes, out_planes, stride=1):
#     """3x3 convolution with padding"""
#     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
#                      padding=1, bias=False)
#
# class BasicBlock(nn.Module):
#     expansion = 1
#
#     def __init__(self, inplanes, planes, stride=1, downsample=None):
#         super(BasicBlock, self).__init__()
#         self.conv1 = conv3x3(inplanes, planes, stride)
#         self.bn1 = nn.BatchNorm2d(planes)
#         self.relu = nn.ReLU(inplace=True)
#         self.conv2 = conv3x3(planes, planes)
#         self.bn2 = nn.BatchNorm2d(planes)
#         self.downsample = downsample
#         self.stride = stride
#
#     def forward(self, x):
#         residual = x
#
#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.relu(out)
#
#         out = self.conv2(out)
#         out = self.bn2(out)
#
#         if self.downsample is not None:
#             residual = self.downsample(x)
#
#         out += residual
#         out = self.relu(out)
#
#         return out
#
# class Bottleneck(nn.Module):
#     expansion = 4
#
#     def __init__(self, inplanes, planes, stride=1, downsample=None):
#         super(Bottleneck, self).__init__()
#         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
#         self.bn1 = nn.BatchNorm2d(planes)
#         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
#                                padding=1, bias=False)
#         self.bn2 = nn.BatchNorm2d(planes)
#         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
#         self.bn3 = nn.BatchNorm2d(planes * 4)
#         self.relu = nn.ReLU(inplace=True)
#         self.downsample = downsample
#         self.stride = stride
#
#     def forward(self, x):
#         residual = x
#
#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.relu(out)
#
#         out = self.conv2(out)
#         out = self.bn2(out)
#         out = self.relu(out)
#
#         out = self.conv3(out)
#         out = self.bn3(out)
#
#         if self.downsample is not None:
#             residual = self.downsample(x)
#
#         out += residual
#         out = self.relu(out)
#
#         return out
#
#
# class ResNet(nn.Module):
#
#     def __init__(self, block, layers, num_classes=1000):
#         self.inplanes = 64
#         super(ResNet, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
#                                bias=False)
#         self.bn1 = nn.BatchNorm2d(64)
#         self.relu = nn.ReLU(inplace=True)
#         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
#         self.layer1 = self._make_layer(block, 64, layers[0])
#         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
#         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
#         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
#         self.avgpool = nn.AvgPool2d(7, stride=1)
#         self.fc = nn.Linear(512 * block.expansion, num_classes)
#
#         self.apply(initweights)
#
#     def _make_layer(self, block, planes, blocks, stride=1):
#         downsample = None
#         if stride != 1 or self.inplanes != planes * block.expansion:
#             downsample = nn.Sequential(
#                 nn.Conv2d(self.inplanes, planes * block.expansion,
#                           kernel_size=1, stride=stride, bias=False),
#                 nn.BatchNorm2d(planes * block.expansion),
#             )
#
#         layers = []
#         layers.append(block(self.inplanes, planes, stride, downsample))
#         self.inplanes = planes * block.expansion
#         for i in range(1, blocks):
#             layers.append(block(self.inplanes, planes))
#
#         return nn.Sequential(*layers)
#
#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.bn1(x)
#         x = self.relu(x)
#         x = self.maxpool(x)
#
#         x = self.layer1(x)
#         x = self.layer2(x)
#         x = self.layer3(x)
#         x = self.layer4(x)
#
#         x = self.avgpool(x)
#         x = x.view(x.size(0), -1)
#         x = self.fc(x)
#
#         return x
#
# class ResNetImageNet(nn.Module):
#
#     def __init__(self, opt, num_classes=1000):
#         super(ResNetImageNet, self).__init__()
#         self.opt = opt
#         self.depth = opt.depth
#         self.num_classes = num_classes
#         if self.depth == 18:
#         	self.model = ResNet(BasicBlock, [2, 2, 2, 2], self.num_classes)
#         elif self.depth == 34:
#         	self.model = ResNet(BasicBlock, [3, 4, 6, 3], self.num_classes)
#         elif self.depth == 50:
#         	self.model = ResNet(Bottleneck, [3, 4, 6, 3], self.num_classes)
#         elif self.depth == 101:
#         	self.model = ResNet(Bottleneck, [3, 4, 23, 3], self.num_classes)
#         elif self.depth == 152:
#         	self.model = ResNet(Bottleneck, [3, 8, 36, 3], self.num_classes)
#
#
#     def forward(self, x):
#         x = self.model(x)
#         return x
#
#
#
# # def resnet18(pretrained=False, **kwargs):
# #     """Constructs a ResNet-18 model.
# #     Args:
# #         pretrained (bool): If True, returns a model pre-trained on ImageNet
# #     """
# #     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
# #     if pretrained:
# #         model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
# #     return model
Esempio n. 20
0
def generate_model(opt):
    """Build the network named by ``opt.model`` and optionally load weights.

    Args:
        opt: Options object (presumably an argparse namespace -- confirm
            against the caller). Attributes read here: ``model``,
            ``model_depth``, ``n_classes``, ``resnet_shortcut``,
            ``sample_size``, ``sample_duration``, ``is_gray``, ``no_cuda``,
            ``pretrain_path``, ``two_step``, ``test``, ``is_AE``,
            ``n_finetune_classes`` and ``ft_begin_index``.

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` is
        ``model.parameters()`` except on the CPU path with a checkpoint,
        where it is the result of
        ``get_fine_tuning_parameters(model, opt.ft_begin_index)``.

    Raises:
        AssertionError: If ``opt.model`` / ``opt.model_depth`` is outside the
            accepted lists, or (CPU path) the checkpoint ``arch`` is neither
            ``'resnet'`` nor ``'resnet_AE'``.
        ValueError: If a depth other than 18, 34 or 50 is requested for
            ``resnet_AE``, ``resnet_mask`` or ``resnet_comp``.
    """
    assert opt.model in [
        'resnet', 'resnet_AE', 'resnet_mask', 'resnet_comp', 'unet', 'icnet',
        'icnet_res', 'icnet_res_2D', 'icnet_res_2Dt', 'icnet_DBI',
        'icnet_deep', 'icnet_deep_gate', 'icnet_deep_gate_2step'
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        # Constructors are named ``resnet<depth>``; look up only the one
        # that was asked for.
        build = getattr(resnet, 'resnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)

    elif opt.model in ('resnet_AE', 'resnet_mask', 'resnet_comp'):
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        if opt.model == 'resnet_AE':
            from models.resnet_AE import get_fine_tuning_parameters
            family = resnet_AE
        elif opt.model == 'resnet_mask':
            from models.resnet_mask import get_fine_tuning_parameters
            family = resnet_mask
        else:
            from models.resnet_comp import get_fine_tuning_parameters
            family = resnet_comp

        # Only these depths were ever wired up for the three variants; fail
        # with a clear message instead of an UnboundLocalError further down.
        if opt.model_depth not in (18, 34, 50):
            raise ValueError(
                'model_depth {} is not implemented for {} (use 18, 34 or 50)'
                .format(opt.model_depth, opt.model))

        build = getattr(family, 'resnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration,
                      is_gray=opt.is_gray,
                      opt=opt)

    # The remaining architectures are configured entirely through ``opt``.
    elif opt.model == 'unet':
        model = unet_mask.UNet3D(opt=opt)
    elif opt.model == 'icnet':
        model = icnet_mask.ICNet3D(opt=opt)
    elif opt.model == 'icnet_res':
        model = icnet_res.ICNetResidual3D(opt=opt)
    elif opt.model == 'icnet_res_2D':
        model = icnet_res.ICNetResidual2D(opt=opt)
    elif opt.model == 'icnet_res_2Dt':
        model = icnet_res.ICNetResidual2Dt(opt=opt)
    elif opt.model == 'icnet_DBI':
        model = icnet_res.ICNetResidual_DBI(opt=opt)
    elif opt.model == 'icnet_deep':
        model = icnet_res.ICNetDeep(opt=opt)
    elif opt.model == 'icnet_deep_gate':
        model = icnet_res.ICNetDeepGate(opt=opt)
    elif opt.model == 'icnet_deep_gate_2step':
        model = icnet_res.ICNetDeepGate2step(opt=opt)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # NOTE: torch.load unpickles the file; only load trusted
            # checkpoints.
            pretrain = torch.load(opt.pretrain_path)
            print('loading from', pretrain['arch'])

            # In two-step test mode the weights live under 'state_dict_1'.
            if opt.two_step and opt.test:
                source_state = pretrain['state_dict_1']
            else:
                source_state = pretrain['state_dict']

            # Partial load: copy every tensor whose key exists in the
            # checkpoint and report the keys that keep their fresh values.
            child_dict = model.state_dict()
            print('Not loaded :')
            parent_dict = {}
            for key in child_dict:
                if key in source_state:
                    parent_dict[key] = source_state[key]
                else:
                    print(key)
            print('length :', len(parent_dict.keys()))
            child_dict.update(parent_dict)
            model.load_state_dict(child_dict)

            if not opt.is_AE:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                # The replacement head is created on the CPU; move it next
                # to the rest of the (already CUDA-resident) network.
                model.module.fc = model.module.fc.cuda()
            return model, model.parameters()

    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # map_location lets a checkpoint saved from GPU tensors be read
            # on a CPU-only run.
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            assert pretrain['arch'] in ['resnet', 'resnet_AE']

            model.load_state_dict(pretrain['state_dict'])

            # No DataParallel wrapper on this path, so the head is reached
            # directly on the model rather than through ``.module``.
            model.fc = nn.Linear(model.fc.in_features,
                                 opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
Esempio n. 21
0
def generate_model(opt):
    """Build the network named by ``opt.model`` and prepare it for training.

    On the CUDA path the function optionally loads a checkpoint, adapts the
    first convolution / input channels to ``opt.modality`` through the
    ``_modify_first_conv_layer`` / ``_construct_*_model`` helpers, replaces
    the classification head, and finally wraps the model in
    ``nn.DataParallel`` and moves it to the GPU.

    Args:
        opt: Options object (presumably an argparse namespace -- confirm
            against the caller). Besides the model hyper-parameters it reads
            ``no_cuda``, ``pretrain_path``, ``pretrain_dataset``, ``dataset``,
            ``arch``, ``modality``, ``no_first_lay``, ``sample_duration``,
            ``n_finetune_classes`` and ``ft_begin_index``.

    Returns:
        A ``(model, parameters)`` tuple. With CUDA enabled ``parameters`` is
        ``get_fine_tuning_parameters(model, opt.ft_begin_index)`` (or
        ``model.trainable_parameters()`` for ``"mstcn"``); otherwise it is
        ``model.parameters()``.

    Raises:
        AssertionError: If ``opt.model`` / ``opt.model_depth`` is outside the
            accepted lists or ``opt.arch`` differs from the checkpoint's
            ``arch``.
    """
    assert opt.model in [
        'resnet', 'resnetl', 'resnext', 'c3d', 'mobilenetv2', 'shufflenetv2',
        "mstcn"
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 50]

        from models.resnet import get_fine_tuning_parameters

        build = getattr(resnet, 'resnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'resnetl':
        assert opt.model_depth in [10, 18]

        from models.resnetl import get_fine_tuning_parameters

        # NOTE(review): depth 18 has always been mapped to ``resnetl10``
        # here. That looks like a copy-paste slip, but whether a
        # ``resnetl18`` constructor exists cannot be told from this file,
        # so the mapping is kept as is -- confirm.
        model = resnetl.resnetl10(num_classes=opt.n_classes,
                                  shortcut_type=opt.resnet_shortcut,
                                  sample_size=opt.sample_size,
                                  sample_duration=opt.sample_duration)

    elif opt.model == 'resnext':
        assert opt.model_depth in [101]

        from models.resnext import get_fine_tuning_parameters

        model = resnext.resnet101(num_classes=opt.n_classes,
                                  shortcut_type=opt.resnet_shortcut,
                                  cardinality=opt.resnext_cardinality,
                                  sample_size=opt.sample_size,
                                  sample_duration=opt.sample_duration)
    elif opt.model == 'c3d':
        assert opt.model_depth in [10]

        from models.c3d import get_fine_tuning_parameters

        model = c3d.c3d_v1(sample_size=opt.sample_size,
                           sample_duration=opt.sample_duration,
                           num_classes=opt.n_classes)
    elif opt.model == 'mobilenetv2':

        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.mob_v2(num_classes=opt.n_classes,
                                   sample_size=opt.sample_size,
                                   width_mult=opt.width_mult)
    elif opt.model == 'shufflenetv2':

        from models.shufflenetv2 import get_fine_tuning_parameters
        model = shufflenetv2.shf_v2(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)

    elif opt.model == "mstcn":
        model = MultiStageTemporalConvNet(
            embed_size=opt.embedding_dim,
            encoder=opt.tcn_encoder,
            n_classes=opt.n_classes,
            input_size=(opt.sample_size, opt.sample_size),
            # Every modality other than plain RGB gets a fourth channel.
            input_channels=4 if opt.modality != "RGB" else 3,
            num_stages=opt.tcn_stages,
            causal_config=opt.tcn_causality,
            CTHW_layout=True,
            use_preprocessing=opt.use_preprocessing)

    if not opt.no_cuda:
        # Stays None when no checkpoint was requested, so the later
        # dataset-specific loading step can be skipped instead of failing on
        # an unbound name.
        pretrain = None

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # NOTE: torch.load unpickles the file; only load trusted
            # checkpoints.
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']
            if opt.pretrain_dataset == 'jester':
                if opt.sample_duration < 32 and opt.model != 'c3d':
                    model = _modify_first_conv_layer(model, 3, 3)
                # Drop the checkpoint's classification head so the
                # non-strict load below leaves the model's own head alone.
                if opt.model in ['mobilenetv2', 'shufflenetv2']:
                    del pretrain['state_dict']['module.classifier.1.weight']
                    del pretrain['state_dict']['module.classifier.1.bias']
                else:
                    del pretrain['state_dict']['module.fc.weight']
                    del pretrain['state_dict']['module.fc.bias']
                model.load_state_dict(pretrain['state_dict'], strict=False)

        if opt.modality in ['RGB', 'flo'] and opt.model != 'c3d':
            print("[INFO]: RGB model is used for init model")
            if opt.dataset != 'jester' and not opt.no_first_lay:
                model = _modify_first_conv_layer(
                    model, 3, 3)  ##### Check models trained (3,7,7) or (7,7,7)
        elif opt.modality in ['Depth', 'seg']:
            print(
                "[INFO]: Converting the pretrained model to Depth init model")
            model = _construct_depth_model(model)
            print("[INFO]: Done. Flow model ready.")
        elif opt.modality in ['RGB-D', 'RGB-flo', 'RGB-seg']:
            if opt.model != "mstcn":
                print(
                    "[INFO]: Converting the pretrained model to RGB+D init model"
                )
                model = _construct_rgbdepth_model(model)
                if opt.no_first_lay:
                    model = _modify_first_conv_layer(
                        model, 3,
                        4)  ##### Check models trained (3,7,7) or (7,7,7)
                print("[INFO]: Done. RGB-D model ready.")

        if pretrain is not None:
            if opt.pretrain_dataset == opt.dataset:
                model.load_state_dict(pretrain['state_dict'])
            elif opt.pretrain_dataset in ['egogesture', 'nv', 'denso']:
                del pretrain['state_dict']['module.fc.weight']
                del pretrain['state_dict']['module.fc.bias']
                model.load_state_dict(pretrain['state_dict'], strict=False)

        # Check first kernel size
        if opt.model != "mstcn":
            # First Conv3d in traversal order; IndexError if there is none.
            conv_layer = [
                layer for layer in model.modules()
                if isinstance(layer, nn.Conv3d)
            ][0]
            if conv_layer.kernel_size[0] > opt.sample_duration:
                print("[INFO]: RGB model is used for init model")
                model = _modify_first_conv_layer(model,
                                                 int(opt.sample_duration / 2),
                                                 1)

        # NOTE(review): ``model`` is wrapped in nn.DataParallel only at the
        # end of this branch, so the ``.module`` accesses below rely on the
        # model (or the helpers above) already exposing that attribute --
        # confirm.
        if opt.model == 'c3d':  # CHECK HERE
            model.module.fc = nn.Linear(model.module.fc[0].in_features,
                                        model.module.fc[0].out_features)
            model.module.fc = model.module.fc.cuda()
        elif opt.model in ['mobilenetv2', 'shufflenetv2']:
            model.module.classifier = nn.Sequential(
                nn.Dropout(0.9),
                nn.Linear(model.module.classifier[1].in_features,
                          opt.n_finetune_classes))
            model.module.classifier = model.module.classifier.cuda()
        elif opt.model != "mstcn":
            model.module.fc = nn.Linear(model.module.fc.in_features,
                                        opt.n_finetune_classes)
            model.module.fc = model.module.fc.cuda()

        if opt.model != "mstcn":
            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        else:
            parameters = model.trainable_parameters()

        model = nn.DataParallel(model, device_ids=None)

        model = model.cuda()
        return model, parameters

    else:
        print('ERROR no cuda')

    return model, model.parameters()
Esempio n. 22
0
def generate_model(opt):
    """Build the network named by ``opt.model`` and optionally fine-tune it.

    When ``opt.pretrain_path`` is set, the checkpoint is loaded, the
    classification head is replaced by a fresh layer with
    ``opt.n_finetune_classes`` outputs, and the parameters selected by the
    architecture's ``get_fine_tuning_parameters`` are returned.

    Args:
        opt: Options object (presumably an argparse namespace -- confirm
            against the caller). Besides the model hyper-parameters it reads
            ``no_cuda``, ``no_cuda_predict``, ``pretrain_path``, ``arch``,
            ``n_finetune_classes`` and ``ft_begin_index``.

    Returns:
        A ``(model, parameters)`` tuple. Without a checkpoint ``parameters``
        is ``model.parameters()``.

    Raises:
        AssertionError: If ``opt.model`` / ``opt.model_depth`` is outside the
            accepted lists or ``opt.arch`` differs from the checkpoint's
            ``arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet',
        'mobilenet', 'mobilenetv2'
    ]

    # The depth-parameterised families name their constructors
    # ``<prefix><depth>``, so the accepted depth selects the constructor.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        build = getattr(resnet, 'resnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     k=opt.wide_resnet_k,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        build = getattr(pre_act_resnet,
                        'resnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        build = getattr(densenet, 'densenet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)

    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)

    if not opt.no_cuda:
        # NOTE(review): with ``no_cuda_predict`` set the backbone is not
        # moved to the GPU, yet the replacement heads below still are --
        # confirm this split is intended.
        if not opt.no_cuda_predict:
            model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # NOTE: torch.load unpickles the file; only load trusted
            # checkpoints.
            pretrain = torch.load(opt.pretrain_path)
            print("Pretrain arch", pretrain['arch'])
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])
            ft_begin_index = opt.ft_begin_index
            # DataParallel exposes the wrapped network as ``model.module``.
            if opt.model in [
                    'mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2'
            ]:
                model.module.classifier = nn.Sequential(
                    nn.Dropout(0.9),
                    nn.Linear(model.module.classifier[1].in_features,
                              opt.n_finetune_classes))
                model.module.classifier = model.module.classifier.cuda()
                # These models take a string mode rather than a block index.
                ft_begin_index = 'complete' if ft_begin_index == 0 else 'last_layer'
            elif opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()
            print("Finetuning at:", ft_begin_index)
            parameters = get_fine_tuning_parameters(model, ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # map_location lets a checkpoint saved from GPU tensors be read
            # on a CPU-only run.
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])
            ft_begin_index = opt.ft_begin_index
            # CPU path: the model is not wrapped, so heads are reached
            # directly on it and nothing is moved to the GPU.
            if opt.model in [
                    'mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2'
            ]:
                model.classifier = nn.Sequential(
                    nn.Dropout(0.9),
                    nn.Linear(model.classifier[1].in_features,
                              opt.n_finetune_classes))
                ft_begin_index = 'complete' if ft_begin_index == 0 else 'last_layer'
            elif opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)
            print("Finetuning at:", ft_begin_index)
            parameters = get_fine_tuning_parameters(model, ft_begin_index)
            return model, parameters

    return model, model.parameters()
Esempio n. 23
0
def generate_model(opt):
    """Build an I3ResNet inflated from an ImageNet-pretrained 2D ResNet.

    Only ``opt.model == 'resnet'`` with ``opt.model_depth`` 50 or 101 is
    supported. The 2D torchvision network is deep-copied and wrapped in
    ``I3ResNet`` together with ``opt.sample_duration``.

    Args:
        opt: Options object. Reads ``model``, ``model_depth``, ``no_cuda``,
            ``sample_duration``, ``pretrain_path``, ``n_finetune_classes``
            and ``ft_begin_index``; the CPU path additionally reads ``arch``.

    Returns:
        A ``(model, parameters)`` tuple. When ``opt.pretrain_path`` is set,
        ``parameters`` comes from ``get_fine_tuning_parameters``; otherwise
        it is ``model.parameters()``.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is not in the
            accepted lists, or (CPU path) the checkpoint arch differs from
            ``opt.arch``.
        ValueError: If the model family or depth has no pretrained 2D
            backbone available here.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    # The per-architecture 3D constructors that used to live here were
    # commented-out dead code and have been removed.
    if opt.model != 'resnet':
        raise ValueError("currently only resnet50 and resnet101 are provided")

    assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

    if opt.model_depth == 50:
        resnet2d = torchvision.models.resnet50(pretrained=True)
    elif opt.model_depth == 101:
        resnet2d = torchvision.models.resnet101(pretrained=True)
    else:
        raise ValueError(
            "Only resnet 50 and 101 pretrained on imagenet are provided now")

    # Build the model before the device branch. It used to be created only
    # inside the CUDA branch, which left `model` unbound when opt.no_cuda
    # was set.
    model = I3ResNet(copy.deepcopy(resnet2d), opt.sample_duration)

    # NOTE(review): get_fine_tuning_parameters is not imported inside this
    # function (the local imports were part of the removed dead code), so it
    # must be available at module level -- confirm.
    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            # The arch check is deliberately skipped on this path (it was
            # already commented out in the original).
            model.load_state_dict(pretrain['state_dict'])

            # Only 'resnet' reaches this point, so the head is always `fc`.
            model.module.fc = nn.Linear(model.module.fc.in_features,
                                        opt.n_finetune_classes)
            model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    elif opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        # map_location keeps GPU-saved checkpoints loadable without CUDA.
        pretrain = torch.load(opt.pretrain_path, map_location='cpu')
        assert opt.arch == pretrain['arch']

        model.load_state_dict(pretrain['state_dict'])

        model.fc = nn.Linear(model.fc.in_features, opt.n_finetune_classes)

        parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        return model, parameters

    return model, model.parameters()
Esempio n. 24
0
def generate_model(opt):
    """Build the 3D network selected by ``opt`` and adapt it to the modality.

    After construction the model is optionally initialised from
    ``opt.pretrain_path``, converted for the requested ``opt.modality``
    ('RGB', 'Depth' or 'RGB-D'), and its first ``nn.Conv3d`` is rebuilt via
    ``_modify_first_conv_layer`` when its temporal kernel size exceeds
    ``opt.sample_duration``.

    Args:
        opt: Options object. Reads ``model``, ``model_depth``, ``n_classes``,
            ``sample_size``, ``sample_duration``, ``no_cuda``,
            ``pretrain_path``, ``arch``, ``n_finetune_classes``,
            ``modality`` and ``ft_begin_index``, plus ``resnet_shortcut`` /
            ``resnext_cardinality`` for the architectures that use them.

    Returns:
        A ``(model, parameters)`` tuple where ``parameters`` is the result of
        ``get_fine_tuning_parameters(model, opt.ft_begin_index)``.

    Raises:
        AssertionError: If the model name or depth is unsupported, or the
            checkpoint's ``arch`` differs from ``opt.arch``.
        IndexError: If the model contains no ``nn.Conv3d`` layer.
    """
    assert opt.model in ['resnet', 'resnetl', 'resnext', 'c3d']

    if opt.model == 'resnet':
        assert opt.model_depth in [10]

        from models.resnet import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnet.resnet10(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'resnetl':
        assert opt.model_depth in [10]

        from models.resnetl import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnetl.resnetl10(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)

    elif opt.model == 'resnext':
        assert opt.model_depth in [101]

        from models.resnext import get_fine_tuning_parameters

        if opt.model_depth == 101:
            model = resnext.resnet101(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
    elif opt.model == 'c3d':
        assert opt.model_depth in [10]

        from models.c3d import get_fine_tuning_parameters
        if opt.model_depth == 10:
            model = c3d.c3d_v1(sample_size=opt.sample_size,
                               sample_duration=opt.sample_duration,
                               num_classes=opt.n_classes)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']
            model.load_state_dict(pretrain['state_dict'])

            # For c3d the code indexes `fc[0]`, i.e. `fc` is treated as a
            # container whose first element carries `in_features`.
            if opt.model == 'c3d':
                head_inputs = model.module.fc[0].in_features
            else:
                head_inputs = model.module.fc.in_features
            model.module.fc = nn.Linear(head_inputs,
                                        opt.n_finetune_classes).cuda()

        if opt.modality == 'RGB' and opt.model != 'c3d':
            print("[INFO]: RGB model is used for init model")
            # NOTE(review): original author asked to check whether the
            # models were trained with (3,7,7) or (7,7,7) kernels.
            model = _modify_first_conv_layer(model, 3, 3)
        elif opt.modality == 'Depth':
            print(
                "[INFO]: Converting the pretrained model to Depth init model")
            model = _construct_depth_model(model)
            # Message previously said "Flow model ready" (copy/paste slip).
            print("[INFO]: Done. Depth model ready.")
        elif opt.modality == 'RGB-D':
            print(
                "[INFO]: Converting the pretrained model to RGB+D init model")
            model = _construct_rgbdepth_model(model)
            print("[INFO]: Done. RGB-D model ready.")

        # First Conv3d in module traversal order; IndexError if none exists.
        first_conv = [
            layer for layer in model.modules()
            if isinstance(layer, nn.Conv3d)
        ][0]
        if first_conv.kernel_size[0] > opt.sample_duration:
            model = _modify_first_conv_layer(model,
                                             int(opt.sample_duration / 2), 1)

        parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        return model, parameters

    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # map_location keeps GPU-saved checkpoints loadable without CUDA.
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            assert opt.arch == pretrain['arch']
            model.load_state_dict(pretrain['state_dict'])

        if opt.modality == 'RGB' and opt.model != 'c3d':
            print("[INFO]: RGB model is used for init model")
            model = _modify_first_conv_layer(model, 3, 3)
        elif opt.modality == 'Depth':
            print(
                "[INFO]: Converting the pretrained model to Depth init model")
            model = _construct_depth_model(model)
            # Typo fixed: message previously read "Deoth model ready.".
            print("[INFO]: Depth model ready.")
        elif opt.modality == 'RGB-D':
            print(
                "[INFO]: Converting the pretrained model to RGB-D init model")
            model = _construct_rgbdepth_model(model)
            print("[INFO]: Done. RGB-D model ready.")

        # First Conv3d in module traversal order; IndexError if none exists.
        first_conv = [
            layer for layer in model.modules()
            if isinstance(layer, nn.Conv3d)
        ][0]
        if first_conv.kernel_size[0] > opt.sample_duration:
            # NOTE(review): this message is unrelated to the temporal-kernel
            # shrink performed here and has no CUDA-path counterpart --
            # confirm whether it should be reworded or removed.
            print("[INFO]: RGB model is used for init model")
            model = _modify_first_conv_layer(model,
                                             int(opt.sample_duration / 2), 1)

        # NOTE(review): unlike the CUDA path, the head is replaced here even
        # when no checkpoint was loaded, and the c3d head keeps its original
        # out_features instead of opt.n_finetune_classes -- confirm intent.
        if opt.model == 'c3d':
            model.fc = nn.Linear(model.fc[0].in_features,
                                 model.fc[0].out_features)
        else:
            model.fc = nn.Linear(model.fc.in_features, opt.n_finetune_classes)

        parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        return model, parameters
Esempio n. 25
0
def generate_model(opt):
    """Build the network selected by ``opt`` and optionally prepare fine-tuning.

    The constructor is looked up by name on the architecture module
    (``resnet``, ``wide_resnet``, ``resnext``, ``pre_act_resnet``,
    ``densenet`` or ``senet``). When ``opt.pretrain_path`` is set, checkpoint
    weights are loaded and the classification head is replaced by a fresh
    ``nn.Linear`` with ``opt.n_finetune_classes`` outputs.

    Args:
        opt: Options object. Reads ``model``, ``model_depth``, ``n_classes``,
            ``no_cuda``, ``pretrain_path``, ``n_finetune_classes`` and
            ``ft_begin_index``; depending on the architecture also
            ``sample_size``, ``sample_duration``, ``resnet_shortcut``,
            ``wide_resnet_k`` and ``resnext_cardinality``. The CPU path
            additionally reads ``arch``.

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` comes from
        ``get_fine_tuning_parameters`` when a checkpoint was loaded for a
        non-senet model, and is ``model.parameters()`` otherwise.

    Raises:
        AssertionError: If the model name or depth is unsupported, or (CPU
            path) the checkpoint's ``arch`` differs from ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet', 'senet'
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        # Constructors are named resnet<depth>; the assert above bounds the
        # set of names that can be requested.
        build = getattr(resnet, 'resnet{}'.format(opt.model_depth))
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = wide_resnet.resnet50(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            k=opt.wide_resnet_k,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnet{}'.format(opt.model_depth))
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        build = getattr(pre_act_resnet, 'resnet{}'.format(opt.model_depth))
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        build = getattr(densenet, 'densenet{}'.format(opt.model_depth))
        model = build(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'senet':
        assert opt.model_depth in [50, 101, 152, 154, 5032, 10132]

        senet_factories = {
            50: 'se_resnet50',
            101: 'se_resnet101',
            152: 'se_resnet152',
            154: 'senet154',
            # NOTE(review): every other entry is an SE variant; confirm that
            # 'resnext50_32x4d' (not 'se_resnext50_32x4d') is intended.
            5032: 'resnext50_32x4d',
            10132: 'se_resnext101_32x4d',
        }
        build = getattr(senet, senet_factories[opt.model_depth])
        model = build(num_classes=opt.n_classes, pretrained=None)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            # The arch check is deliberately skipped on this path (it was
            # already commented out in the original).

            # Keep the freshly initialised `fc` head and take everything
            # else from the checkpoint. The prefix test only drops the
            # top-level head; the former substring test
            # (`k.find("module.fc")`) also dropped any key that merely
            # contained "module.fc", e.g. "...se_module.fc1...".
            model_dict = model.state_dict()
            model_dict.update({
                key: value
                for key, value in pretrain['state_dict'].items()
                if not key.startswith('module.fc.')
            })
            model.load_state_dict(model_dict)

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            elif opt.model == 'senet':
                model.module.last_linear = nn.Linear(
                    model.module.last_linear.in_features,
                    opt.n_finetune_classes)
                model.module.last_linear = model.module.last_linear.cuda()
                # No get_fine_tuning_parameters is imported for senet, so
                # all parameters are returned.
                return model, model.parameters()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    elif opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        # map_location keeps GPU-saved checkpoints loadable without CUDA.
        pretrain = torch.load(opt.pretrain_path, map_location='cpu')
        assert opt.arch == pretrain['arch']

        model.load_state_dict(pretrain['state_dict'])

        if opt.model == 'densenet':
            model.classifier = nn.Linear(model.classifier.in_features,
                                         opt.n_finetune_classes)
        elif opt.model == 'senet':
            # Mirrors the CUDA path. Previously senet fell through to
            # `model.fc` and then to get_fine_tuning_parameters, which is
            # never bound for senet.
            model.last_linear = nn.Linear(model.last_linear.in_features,
                                          opt.n_finetune_classes)
            return model, model.parameters()
        else:
            model.fc = nn.Linear(model.fc.in_features,
                                 opt.n_finetune_classes)

        parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        return model, parameters

    return model, model.parameters()
Esempio n. 26
0
def generate_model(opt):
    """Build the 3D network selected by ``opt`` and optionally prepare fine-tuning.

    The constructor is looked up by name on the architecture module
    (``resnet``, ``wide_resnet``, ``resnext``, ``pre_act_resnet`` or
    ``densenet``). When ``opt.pretrain_path`` is set, the checkpoint is
    loaded, the classification head is replaced by a fresh ``nn.Linear``
    with ``opt.n_finetune_classes`` outputs, and the fine-tuning parameter
    selection is returned.

    Args:
        opt: Options object. Reads ``model``, ``model_depth``, ``n_classes``,
            ``sample_size``, ``sample_duration``, ``no_cuda``,
            ``pretrain_path``, ``arch``, ``n_finetune_classes`` and
            ``ft_begin_index``; depending on the architecture also
            ``resnet_shortcut``, ``wide_resnet_k`` and
            ``resnext_cardinality``.

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` comes from
        ``get_fine_tuning_parameters`` when a checkpoint was loaded and is
        ``model.parameters()`` otherwise.

    Raises:
        AssertionError: If the model name or depth is unsupported, or the
            checkpoint's ``arch`` differs from ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        # Constructors are named resnet<depth>; the assert above bounds the
        # set of names that can be requested.
        factory = getattr(resnet, 'resnet{}'.format(opt.model_depth))
        model = factory(num_classes=opt.n_classes,
                        shortcut_type=opt.resnet_shortcut,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     k=opt.wide_resnet_k,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        factory = getattr(resnext, 'resnet{}'.format(opt.model_depth))
        model = factory(num_classes=opt.n_classes,
                        shortcut_type=opt.resnet_shortcut,
                        cardinality=opt.resnext_cardinality,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        factory = getattr(pre_act_resnet,
                          'resnet{}'.format(opt.model_depth))
        model = factory(num_classes=opt.n_classes,
                        shortcut_type=opt.resnet_shortcut,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        factory = getattr(densenet, 'densenet{}'.format(opt.model_depth))
        model = factory(num_classes=opt.n_classes,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    elif opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        pretrain = torch.load(opt.pretrain_path, map_location='cpu')
        assert opt.arch == pretrain['arch']

        from collections import OrderedDict

        # The model is not wrapped in DataParallel on this path, so a
        # leading `module.` on checkpoint keys has to go. Strip it only
        # where present; the former unconditional `k[7:]` mangled keys that
        # had no such prefix.
        wrapper_prefix = 'module.'
        new_state_dict = OrderedDict()
        for key, value in pretrain['state_dict'].items():
            if key.startswith(wrapper_prefix):
                key = key[len(wrapper_prefix):]
            new_state_dict[key] = value
        model.load_state_dict(new_state_dict)

        if opt.model == 'densenet':
            model.classifier = nn.Linear(model.classifier.in_features,
                                         opt.n_finetune_classes)
        else:
            model.fc = nn.Linear(model.fc.in_features,
                                 opt.n_finetune_classes)

        parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        return model, parameters

    return model, model.parameters()
Esempio n. 27
0
def generate_model(opt):
    """Build the network selected by ``opt`` and return it with its parameters.

    ``opt.mode`` must be ``'score'`` or ``'feature'``; it is converted into the
    ``last_fc`` flag (True for ``'score'``, False for ``'feature'``) that is
    forwarded to the model constructor.  ``opt.model_name`` selects the module
    (``resnet``, ``wide_resnet``, ``resnext``, ``pre_act_resnet`` or
    ``densenet``) and ``opt.model_depth`` selects the constructor inside it.

    Returns:
        A ``(model, polices)`` tuple.  For ``'resnet'`` the second item is the
        result of ``resnet.get_fine_tuning_parameters(model,
        opt.ft_begin_index)``.  For every other architecture it is
        ``model.parameters()``; those paths used to leave ``polices`` unbound
        and crashed with ``UnboundLocalError`` on return.

    Raises:
        AssertionError: if the mode, model name or depth is not supported.
    """
    assert opt.mode in ['score', 'feature']
    last_fc = opt.mode == 'score'

    assert opt.model_name in ['resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet']

    # Keyword arguments that every constructor used here receives.
    kwargs = dict(num_classes=opt.n_classes,
                  sample_size=opt.sample_size,
                  sample_duration=opt.sample_duration,
                  last_fc=last_fc)

    # Constructors are looked up by name ("resnet50", "densenet121", ...), so
    # only the requested depth has to exist in the module.
    if opt.model_name == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        kwargs.update(shortcut_type=opt.resnet_shortcut)
    elif opt.model_name == 'wideresnet':
        assert opt.model_depth in [50]
        build = getattr(wide_resnet, 'resnet%d' % opt.model_depth)
        kwargs.update(shortcut_type=opt.resnet_shortcut, k=opt.wide_resnet_k)
    elif opt.model_name == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        build = getattr(resnext, 'resnet%d' % opt.model_depth)
        kwargs.update(shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality)
    elif opt.model_name == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]
        build = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)
        kwargs.update(shortcut_type=opt.resnet_shortcut)
    elif opt.model_name == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]
        build = getattr(densenet, 'densenet%d' % opt.model_depth)

    model = build(**kwargs)

    if opt.model_name == 'resnet':
        polices = resnet.get_fine_tuning_parameters(model, opt.ft_begin_index)
    else:
        # No fine-tuning helper is used for these architectures in this
        # function, so hand back all parameters instead of failing on return.
        polices = model.parameters()
    return model, polices
Esempio n. 28
0
def generate_model(opt):
    """Build the network selected by ``opt`` and optionally load a checkpoint.

    ``opt.model`` picks the architecture family and, for the ResNet-style and
    DenseNet families, ``opt.model_depth`` picks the constructor.  ``'i3d'``
    and ``'i3dv2'`` have a single constructor each and take no depth.

    When ``opt.no_cuda`` is false the model is moved to CUDA and wrapped in
    ``nn.DataParallel``; otherwise it is left as built.

    When ``opt.pretrain_path`` is set the checkpoint is loaded and the final
    layer (``classifier`` for DenseNet, ``fc`` otherwise) is replaced by a
    fresh ``nn.Linear`` with ``opt.n_finetune_classes`` outputs:

    * ResNet-style / DenseNet checkpoints must be a dict with ``'arch'``
      (checked against ``opt.arch``) and ``'state_dict'``.
    * ``'i3d'`` / ``'i3dv2'`` checkpoints are treated as a flat mapping of
      parameter names to tensors and merged into the model's own state dict.

    Returns:
        ``(model, parameters)`` where ``parameters`` comes from the selected
        module's ``get_fine_tuning_parameters(model, opt.ft_begin_index)``
        when a checkpoint was loaded, and from ``model.parameters()``
        otherwise.

    Raises:
        AssertionError: on an unsupported model/depth or an ``arch`` mismatch.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet', 'i3d',
        'i3dv2'
    ]

    # Constructors are looked up by name ("resnet50", "densenet121", ...), so
    # only the requested depth has to exist in the module.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        model = getattr(resnet, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = getattr(wide_resnet, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            k=opt.wide_resnet_k,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        model = getattr(resnext, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        model = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        model = getattr(densenet, 'densenet%d' % opt.model_depth)(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == "i3d":

        from models.i3dpt import get_fine_tuning_parameters

        model = i3dpt.I3D(num_classes=opt.n_classes, dropout_prob=0.5)

    elif opt.model == "i3dv2":

        from models.I3D_Pytorch import get_fine_tuning_parameters

        model = I3D_Pytorch.I3D(num_classes=opt.n_classes,
                                dropout_keep_prob=0.5)

    is_i3d = opt.model in ("i3d", "i3dv2")

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)

            if not is_i3d:
                assert opt.arch == pretrain['arch']
                model.load_state_dict(pretrain['state_dict'])
            else:
                # The model is wrapped in DataParallel here, so its keys carry
                # a "module." prefix that the flat checkpoint keys lack.
                pretrain = {"module." + k: v for k, v in pretrain.items()}
                model_dict = model.state_dict()
                model_dict.update(pretrain)
                model.load_state_dict(model_dict)

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features,
                    opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # map_location='cpu' lets a checkpoint that was saved from GPU
            # tensors be loaded on a host without CUDA.
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')

            if not is_i3d:
                assert opt.arch == pretrain['arch']

                # Accept checkpoints saved from a DataParallel-wrapped model:
                # drop a leading "module." from keys the bare model does not
                # own.  Keys that already match are left untouched.
                own_keys = model.state_dict().keys()
                state_dict = {}
                for key, value in pretrain['state_dict'].items():
                    if key not in own_keys and key.startswith('module.'):
                        key = key[len('module.'):]
                    state_dict[key] = value
                model.load_state_dict(state_dict)
            else:
                # Same merge as the CUDA branch, without the "module." prefix
                # because the model is not wrapped on this path.
                model_dict = model.state_dict()
                model_dict.update(pretrain)
                model.load_state_dict(model_dict)

            if opt.model == 'densenet':
                model.classifier = nn.Linear(model.classifier.in_features,
                                             opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
Esempio n. 29
0
def generate_model(config):
    """Create the configured backbone, wrap it, and collect its parameters.

    Steps, in order:

    1. Build the backbone named by ``config.base_model`` at depth
       ``config.model_depth``, importing that family's
       ``get_fine_tuning_parameters`` along the way.
    2. If ``config.cuda_available``, wrap it in ``nn.DataParallel`` on
       ``config.cuda_id0`` and, unless it is ``-1``, ``config.cuda_id1``.
    3. If ``config.pretrain_path`` is set, pass the backbone through
       ``load_pretrained``.
    4. Wrap the backbone in ``EmbeddingModel`` when ``config.use_embeddings``
       is set, otherwise pass it through ``update_with_finetune_block``.

    Returns:
        ``(model, parameters)`` where ``parameters`` is the result of
        ``get_fine_tuning_parameters(model, config.ft_begin_index)``.
    """
    assert config.base_model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]
    print(config.base_model, config.n_classes)

    if config.base_model == 'resnet':
        from models.resnet import get_fine_tuning_parameters

        builder = get_resnet_model(config.model_depth)
        base_model = builder(num_classes=config.n_classes,
                             shortcut_type=config.resnet_shortcut,
                             sample_size=config.sample_size,
                             sample_duration=config.sample_duration)

    elif config.base_model == 'wideresnet':
        assert config.model_depth in [50], (
            f'Wideresnet model depth {config.model_depth} is not valid.'
            f' Acceptable: {WIDERESNET_DEPTHS}')

        from models.wide_resnet import get_fine_tuning_parameters

        if config.model_depth == 50:
            base_model = wide_resnet.resnet50(
                num_classes=config.n_classes,
                shortcut_type=config.resnet_shortcut,
                k=config.wide_resnet_k,
                sample_size=config.sample_size,
                sample_duration=config.sample_duration)

    elif config.base_model == 'resnext':
        from models.resnext import get_fine_tuning_parameters

        builder = get_resnext_model(config.model_depth)
        base_model = builder(num_classes=config.n_classes,
                             shortcut_type=config.resnet_shortcut,
                             cardinality=config.resnext_cardinality,
                             sample_size=config.sample_size,
                             sample_duration=config.sample_duration)

    elif config.base_model == 'preresnet':
        from models.pre_act_resnet import get_fine_tuning_parameters

        builder = get_preresnet_model(config.model_depth)
        base_model = builder(num_classes=config.n_classes,
                             shortcut_type=config.resnet_shortcut,
                             sample_size=config.sample_size,
                             sample_duration=config.sample_duration)

    elif config.base_model == 'densenet':
        from models.densenet import get_fine_tuning_parameters

        builder = get_densenet_model(config.model_depth)
        base_model = builder(num_classes=config.n_classes,
                             sample_size=config.sample_size,
                             sample_duration=config.sample_duration)

    if config.cuda_available:
        # One device by default; a second one only when cuda_id1 is not -1.
        gpu_ids = [config.cuda_id0]
        if config.cuda_id1 != -1:
            gpu_ids.append(config.cuda_id1)
        base_model = nn.DataParallel(base_model, device_ids=gpu_ids)

    if config.pretrain_path:
        base_model = load_pretrained(base_model, config.base_model,
                                     config.model_depth, config.pretrain_path)

    model = (EmbeddingModel(base_model, config)
             if config.use_embeddings
             else update_with_finetune_block(base_model, config))

    # should not fail due to assert in the beginning
    parameters = get_fine_tuning_parameters(model, config.ft_begin_index)

    return model, parameters
Esempio n. 30
0
def generate_model(opt):
    """Build a ResNet/ResNeXt variant from ``opt`` and optionally load weights.

    ``opt.model`` (``'resnet'`` or ``'resnext'``) and ``opt.model_depth``
    select the constructor.  Besides the usual size options, the constructor
    receives ``isSource``, ``transfer_module``, ``sourceKind``, ``layer_num``
    and ``multi_source`` straight from ``opt``.

    With ``opt.pretrain_path`` set, the checkpoint must be a dict holding
    ``'arch'`` (checked against ``opt.arch``) and ``'state_dict'``.  Every
    ``'module.'`` substring is removed from the checkpoint keys and the
    result is merged into the model's own state dict before loading.

    * CUDA path (``opt.no_cuda`` false): the model is moved to CUDA and
      wrapped once in ``nn.DataParallel``.  If ``opt.inference == False`` the
      ``fc`` layer is replaced by an ``nn.Linear`` with
      ``opt.n_finetune_classes`` outputs and the fine-tuning parameters are
      returned; otherwise ``model.parameters()`` is returned.
    * CPU path (``opt.no_cuda`` true): the checkpoint is loaded with
      ``map_location='cpu'`` and the model stays on the CPU, unwrapped.

    Without ``opt.pretrain_path`` the freshly built model is returned as is
    (not moved to CUDA, not wrapped) together with ``model.parameters()``.

    Returns:
        A ``(model, parameters)`` tuple.

    Raises:
        AssertionError: on an unsupported model/depth or an ``arch`` mismatch.
    """
    assert opt.model in [
        'resnet', 'resnext'
    ]

    # Constructors are looked up by name ("resnet50", ...), so only the
    # requested depth has to exist in the module.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        extra_kwargs = {}
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnet%d' % opt.model_depth)
        extra_kwargs = {'cardinality': opt.resnext_cardinality}

    model = build(
        num_classes=opt.n_classes,
        shortcut_type=opt.resnet_shortcut,
        sample_size=opt.sample_size,
        sample_duration=opt.sample_duration,
        isSource=opt.isSource,
        transfer_module=opt.transfer_module,
        sourceKind=opt.sourceKind,
        layer_num=opt.layer_num,
        multi_source=opt.multi_source,
        **extra_kwargs)

    print(opt.no_cuda)
    print(type(opt.no_cuda))
    if not opt.no_cuda:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            print('loading pretrained model arch', pretrain['arch'], opt.arch)
            assert opt.arch == pretrain['arch']

            # The model is still unwrapped at this point, so "module." is
            # removed from the checkpoint keys before merging.
            pretrained_dict = {
                str.replace(k, 'module.', ''): v
                for k, v in pretrain['state_dict'].items()
            }
            model_dict = model.state_dict()
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            model = model.cuda()
            model = nn.DataParallel(model, device_ids=None)

            # Deliberately "== False": a value such as None is neither
            # fine-tuning nor inference and takes the plain return below.
            if opt.inference == False:  # noqa: E712
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

                parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)

                print(model)
                return model, parameters

            # The model is already on CUDA and wrapped; wrapping it a second
            # time would only nest one DataParallel inside another.
            return model, model.parameters()
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # Load first, then report: the arch can only be printed once the
            # checkpoint exists.  map_location='cpu' lets a checkpoint saved
            # from GPU tensors be read without CUDA.
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            print('loading pretrained model arch', pretrain['arch'])

            assert opt.arch == pretrain['arch']

            # Same key handling as the CUDA path, so one checkpoint loads on
            # either path.
            pretrained_dict = {
                str.replace(k, 'module.', ''): v
                for k, v in pretrain['state_dict'].items()
            }
            model_dict = model.state_dict()
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            # no_cuda is set: keep the model on the CPU and unwrapped.
            return model, parameters

    return model, model.parameters()
Esempio n. 31
0
def _match_data_parallel_keys(state_dict, wrapped):
    """Re-key ``state_dict`` so it fits a wrapped or unwrapped target model.

    ``nn.DataParallel`` exposes the wrapped network as ``.module``, so the
    keys of its state dict all start with ``'module.'``.  When the checkpoint
    and the target model disagree about that prefix, it is added or removed;
    when they already agree the mapping is returned untouched.

    Args:
        state_dict: mapping of parameter names to tensors from a checkpoint.
        wrapped: True when the target model is an ``nn.DataParallel`` wrapper.

    Returns:
        The original mapping when no change is needed, otherwise a new plain
        dict with adjusted keys (any ``_metadata`` attribute of the original
        mapping is not carried over to the new dict).
    """
    prefix = 'module.'
    if not state_dict:
        return state_dict

    prefixed = all(key.startswith(prefix) for key in state_dict)
    if prefixed == wrapped:
        return state_dict
    if wrapped:
        return {prefix + key: value for key, value in state_dict.items()}
    return {key[len(prefix):]: value for key, value in state_dict.items()}


def generate_model(opt):
    """Build the network selected by ``opt`` and optionally load a checkpoint.

    Attributes read from ``opt``: ``model``, ``model_depth``, ``n_classes``,
    ``sample_size``, ``sample_duration``, ``no_cuda``, ``pretrain_path`` and,
    depending on the branch taken, ``resnet_shortcut``, ``wide_resnet_k``,
    ``resnext_cardinality``, ``arch``, ``n_finetune_classes`` and
    ``ft_begin_index``.

    Unless ``opt.no_cuda`` is set, the model is moved to the GPU and wrapped
    in ``nn.DataParallel``.  When ``opt.pretrain_path`` is set, the checkpoint
    is loaded, the final linear layer (``classifier`` for densenet, ``fc``
    otherwise) is replaced by a new one with ``opt.n_finetune_classes``
    outputs, and the parameters chosen by the model family's
    ``get_fine_tuning_parameters`` are returned.

    Returns:
        A ``(model, parameters)`` tuple; ``parameters`` is
        ``model.parameters()`` when no checkpoint is loaded.

    Raises:
        AssertionError: if ``opt.model`` or ``opt.model_depth`` is not
            supported, or if the checkpoint's ``'arch'`` entry differs from
            ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    # Every family exposes one constructor per depth, named
    # '<prefix><depth>'.  Each branch only picks the module, the name prefix
    # and the family-specific keyword arguments; the call itself is shared.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        family, prefix = resnet, 'resnet'
        family_kwargs = {'shortcut_type': opt.resnet_shortcut}
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        family, prefix = wide_resnet, 'resnet'
        family_kwargs = {
            'shortcut_type': opt.resnet_shortcut,
            'k': opt.wide_resnet_k,
        }
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        family, prefix = resnext, 'resnet'
        family_kwargs = {
            'shortcut_type': opt.resnet_shortcut,
            'cardinality': opt.resnext_cardinality,
        }
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        family, prefix = pre_act_resnet, 'resnet'
        family_kwargs = {'shortcut_type': opt.resnet_shortcut}
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        family, prefix = densenet, 'densenet'
        family_kwargs = {}

    build = getattr(family, '{}{}'.format(prefix, int(opt.model_depth)))
    model = build(
        num_classes=opt.n_classes,
        sample_size=opt.sample_size,
        sample_duration=opt.sample_duration,
        **family_kwargs)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    if use_cuda:
        pretrain = torch.load(opt.pretrain_path)
    else:
        # Remap storages to the CPU so a checkpoint written from GPU tensors
        # can still be read on a machine without CUDA.
        pretrain = torch.load(opt.pretrain_path, map_location='cpu')
    assert opt.arch == pretrain['arch']

    # Checkpoints saved from a DataParallel model carry 'module.'-prefixed
    # keys and ones saved from a bare model do not; adapt to the target.
    model.load_state_dict(
        _match_data_parallel_keys(pretrain['state_dict'], wrapped=use_cuda))

    # Swap the classification head for one sized for the fine-tuning task.
    network = model.module if use_cuda else model
    head_name = 'classifier' if opt.model == 'densenet' else 'fc'
    new_head = nn.Linear(
        getattr(network, head_name).in_features, opt.n_finetune_classes)
    if use_cuda:
        new_head = new_head.cuda()
    setattr(network, head_name, new_head)

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters