Пример #1
0
def generate_model(opt):
    """Build the 3D CNN selected by ``opt.model`` and optionally load weights.

    The architecture is picked by ``opt.model`` (plus ``opt.model_depth``
    for the ResNet families).  Unless ``opt.no_cuda`` is set, the model is
    moved to the GPU and wrapped in ``nn.DataParallel``.  When
    ``opt.pretrain_path`` is set, the checkpoint's ``'state_dict'`` is
    loaded and the classification head is replaced by a freshly initialised
    one with ``opt.n_finetune_classes`` outputs.

    Args:
        opt: options namespace.  Attributes read (depending on the model):
            ``model``, ``model_depth``, ``n_classes``, ``sample_size``,
            ``sample_duration``, ``version``, ``groups``, ``width_mult``,
            ``resnet_shortcut``, ``resnext_cardinality``, ``no_cuda``,
            ``pretrain_path``, ``arch`` (CPU path only),
            ``n_finetune_classes`` and ``ft_portion``.

    Returns:
        ``(model, parameters)``: ``parameters`` is the result of the
        architecture's ``get_fine_tuning_parameters`` when a checkpoint was
        loaded, otherwise ``model.parameters()``.
    """
    assert opt.model in [
        'c3d', 'squeezenet', 'mobilenet', 'resnext', 'resnet', 'resnetl',
        'shufflenet', 'mobilenetv2', 'shufflenetv2'
    ]

    if opt.model == 'c3d':
        from models.c3d import get_fine_tuning_parameters
        model = c3d.get_model(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'squeezenet':
        from models.squeezenet import get_fine_tuning_parameters
        model = squeezenet.get_model(version=opt.version,
                                     num_classes=opt.n_classes,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'shufflenet':
        from models.shufflenet import get_fine_tuning_parameters
        model = shufflenet.get_model(groups=opt.groups,
                                     width_mult=opt.width_mult,
                                     num_classes=opt.n_classes)
    elif opt.model == 'shufflenetv2':
        from models.shufflenetv2 import get_fine_tuning_parameters
        model = shufflenetv2.get_model(num_classes=opt.n_classes,
                                       sample_size=opt.sample_size,
                                       width_mult=opt.width_mult)
    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        # resnext.resnet50 / resnet101 / resnet152 share one signature.
        build = getattr(resnext, 'resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'resnetl':
        assert opt.model_depth in [10]
        from models.resnetl import get_fine_tuning_parameters
        model = resnetl.resnetl10(num_classes=opt.n_classes,
                                  shortcut_type=opt.resnet_shortcut,
                                  sample_size=opt.sample_size,
                                  sample_duration=opt.sample_duration)
    elif opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        # resnet.resnet10 ... resnet200 share one signature.
        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
        pytorch_total_params = sum(p.numel() for p in model.parameters()
                                   if p.requires_grad)
        print("Total number of trainable parameters: ", pytorch_total_params)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    # Deserialize onto the CPU on both paths so GPU-saved checkpoints also
    # load on CPU-only hosts; load_state_dict copies into the model's device.
    pretrain = torch.load(opt.pretrain_path,
                          map_location=torch.device('cpu'))

    if use_cuda:
        model.load_state_dict(pretrain['state_dict'])
        # The real network sits under DataParallel's ``module`` attribute.
        net = model.module
        head_dropout = 0.5
    else:
        assert opt.arch == pretrain['arch']
        # The model is NOT wrapped in DataParallel on this path, so strip the
        # 'module.' prefix a DataParallel-saved checkpoint puts on every key.
        prefix = 'module.'
        state_dict = pretrain['state_dict']
        if all(k.startswith(prefix) for k in state_dict):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        net = model
        # NOTE(review): the CPU path uses 0.9 here while the CUDA path uses
        # 0.5; kept as-is, confirm which rate is intended.
        head_dropout = 0.9

    # Replace the classification head with one sized for fine-tuning.
    if opt.model in [
            'mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2'
    ]:
        net.classifier = nn.Sequential(
            nn.Dropout(head_dropout),
            nn.Linear(net.classifier[1].in_features,
                      opt.n_finetune_classes))
    elif opt.model == 'squeezenet':
        net.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv3d(net.classifier[1].in_channels,
                      opt.n_finetune_classes,
                      kernel_size=1), nn.ReLU(inplace=True),
            nn.AvgPool3d((1, 4, 4), stride=1))
    else:
        net.fc = nn.Linear(net.fc.in_features, opt.n_finetune_classes)

    if use_cuda:
        # The new head was created on the CPU; everything else is already on
        # the GPU, so this effectively moves only the head.
        net.cuda()

    # Both paths call the same function, so both pass ``opt.ft_portion``.
    parameters = get_fine_tuning_parameters(model, opt.ft_portion)
    return model, parameters
Пример #2
0
def generate_model(opt):
    """Build the selected 3D CNN, load optional weights and adapt its input.

    Besides building the architecture named by ``opt.model`` (plus
    ``opt.model_depth`` for the ResNet families) and optionally loading
    ``opt.pretrain_path``, this variant adapts the network to
    ``opt.modality`` (``'RGB'``, ``'Depth'`` or ``'RGB-D'``) through
    ``_modify_first_conv_layer`` / ``_construct_depth_model`` /
    ``_construct_rgbdepth_model``.  It also shrinks the first ``nn.Conv3d``
    when its temporal kernel is longer than ``opt.sample_duration``.

    Args:
        opt: options namespace.  Attributes read (depending on the model):
            ``model``, ``model_depth``, ``n_classes``, ``sample_size``,
            ``sample_duration``, ``version``, ``groups``, ``width_mult``,
            ``resnet_shortcut``, ``resnext_cardinality``, ``no_cuda``,
            ``pretrain_path``, ``arch`` (CPU path only),
            ``n_finetune_classes``, ``modality`` and ``ft_portion``.

    Returns:
        ``(model, parameters)`` where ``parameters`` comes from the chosen
        architecture's ``get_fine_tuning_parameters``.
    """
    assert opt.model in [
        'c3d', 'squeezenet', 'mobilenet', 'resnext', 'resnet', 'resnetl',
        'shufflenet', 'mobilenetv2', 'shufflenetv2'
    ]

    if opt.model == 'c3d':
        from models.c3d import get_fine_tuning_parameters
        model = c3d.get_model(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'squeezenet':
        from models.squeezenet import get_fine_tuning_parameters
        model = squeezenet.get_model(version=opt.version,
                                     num_classes=opt.n_classes,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'shufflenet':
        from models.shufflenet import get_fine_tuning_parameters
        model = shufflenet.get_model(groups=opt.groups,
                                     width_mult=opt.width_mult,
                                     num_classes=opt.n_classes)
    elif opt.model == 'shufflenetv2':
        from models.shufflenetv2 import get_fine_tuning_parameters
        model = shufflenetv2.get_model(num_classes=opt.n_classes,
                                       sample_size=opt.sample_size,
                                       width_mult=opt.width_mult)
    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        # resnext.resnet50 / resnet101 / resnet152 share one signature.
        build = getattr(resnext, 'resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'resnetl':
        assert opt.model_depth in [10]
        from models.resnetl import get_fine_tuning_parameters
        model = resnetl.resnetl10(num_classes=opt.n_classes,
                                  shortcut_type=opt.resnet_shortcut,
                                  sample_size=opt.sample_size,
                                  sample_duration=opt.sample_duration)
    elif opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        # resnet.resnet10 ... resnet200 share one signature.
        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)

    def _replace_head(net, dropout):
        """Swap ``net``'s classification head for an ``n_finetune_classes`` one."""
        if opt.model in [
                'mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2'
        ]:
            net.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(net.classifier[1].in_features,
                          opt.n_finetune_classes))
        elif opt.model == 'squeezenet':
            net.classifier = nn.Sequential(
                nn.Dropout(p=0.5),
                nn.Conv3d(net.classifier[1].in_channels,
                          opt.n_finetune_classes,
                          kernel_size=1), nn.ReLU(inplace=True),
                nn.AvgPool3d((1, 4, 4), stride=1))
        else:
            # NOTE(review): for c3d, ``fc`` is indexed as a container further
            # below (``model.fc[0]``), so ``fc.in_features`` may not exist.
            net.fc = nn.Linear(net.fc.in_features, opt.n_finetune_classes)

    def _first_conv3d(net):
        """Return the first ``nn.Conv3d`` in ``net.modules()``, or None."""
        return next((m for m in net.modules() if isinstance(m, nn.Conv3d)),
                    None)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
        pytorch_total_params = sum(p.numel() for p in model.parameters()
                                   if p.requires_grad)
        print("Total number of trainable parameters: ", pytorch_total_params)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path,
                                  map_location=torch.device('cpu'))
            model.load_state_dict(pretrain['state_dict'])
            _replace_head(model.module, 0.5)
            # The new head is built on the CPU; the rest is already on the GPU.
            model.module.cuda()

        if opt.modality == 'RGB' and opt.model != 'c3d':
            print("[INFO]: RGB model is used for init model")
            # NOTE(review): check whether the source models were trained with
            # (3,7,7) or (7,7,7) first-conv kernels.
            model = _modify_first_conv_layer(model, 3, 3)
        elif opt.modality == 'Depth':
            print(
                "[INFO]: Converting the pretrained model to Depth init model")
            model = _construct_depth_model(model)
            print("[INFO]: Done. Depth model ready.")
        elif opt.modality == 'RGB-D':
            print(
                "[INFO]: Converting the pretrained model to RGB+D init model")
            model = _construct_rgbdepth_model(model)
            print("[INFO]: Done. RGB-D model ready.")

        conv_layer = _first_conv3d(model)
        if (conv_layer is not None
                and conv_layer.kernel_size[0] > opt.sample_duration):
            model = _modify_first_conv_layer(model,
                                             int(opt.sample_duration / 2), 1)

        parameters = get_fine_tuning_parameters(model, opt.ft_portion)
        return model, parameters

    # CPU path: the model is NOT wrapped in nn.DataParallel here, so it is
    # addressed directly rather than through ``.module``.
    if opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        # map_location lets GPU-saved checkpoints load on CPU-only hosts.
        pretrain = torch.load(opt.pretrain_path,
                              map_location=torch.device('cpu'))
        assert opt.arch == pretrain['arch']
        # Strip the 'module.' prefix a DataParallel-saved checkpoint carries.
        prefix = 'module.'
        state_dict = pretrain['state_dict']
        if all(k.startswith(prefix) for k in state_dict):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        # NOTE(review): 0.9 dropout here vs 0.5 on the CUDA path; kept as-is.
        _replace_head(model, 0.9)

    if opt.modality == 'RGB' and opt.model != 'c3d':
        print("[INFO]: RGB model is used for init model")
        model = _modify_first_conv_layer(model, 3, 3)
    elif opt.modality == 'Depth':
        print("[INFO]: Converting the pretrained model to Depth init model")
        model = _construct_depth_model(model)
        print("[INFO]: Depth model ready.")
    elif opt.modality == 'RGB-D':
        print("[INFO]: Converting the pretrained model to RGB-D init model")
        model = _construct_rgbdepth_model(model)
        print("[INFO]: Done. RGB-D model ready.")

    conv_layer = _first_conv3d(model)
    if (conv_layer is not None
            and conv_layer.kernel_size[0] > opt.sample_duration):
        print("[INFO]: First conv temporal kernel exceeds sample_duration; "
              "modifying first conv layer")
        model = _modify_first_conv_layer(model,
                                         int(opt.sample_duration / 2), 1)

    if opt.model == 'c3d':
        # NOTE(review): marked "CHECK HERE" upstream -- this rebuilds c3d's
        # head with its existing in/out sizes rather than n_finetune_classes.
        model.fc = nn.Linear(model.fc[0].in_features,
                             model.fc[0].out_features)
    elif hasattr(model, 'fc'):
        # Classifier-headed models (mobilenet/shufflenet/squeezenet) expose
        # ``classifier`` instead of ``fc``; touching ``fc`` would raise.
        model.fc = nn.Linear(model.fc.in_features, opt.n_finetune_classes)

    # Same call as the CUDA path, so the same ``ft_portion`` argument.
    parameters = get_fine_tuning_parameters(model, opt.ft_portion)
    return model, parameters
Пример #3
0
def generate_model(args):
    """Build a 3D CNN feature extractor, optionally initialised from Kinetics.

    When ``args.pre_train_model`` is false or ``args.mode == 'test'``, the
    network is built from scratch.  Otherwise the matching checkpoint under
    ``./premodels/`` is loaded (only entries whose names also exist in this
    model are used), and the network is then passed through
    ``_construct_depth_model``.  The model is always wrapped in
    ``nn.DataParallel`` and moved to the GPU when ``args.use_cuda`` is set.

    Args:
        args: options namespace.  Attributes read (depending on the model):
            ``model_type``, ``model_depth``, ``feature_dim``,
            ``sample_size``, ``sample_duration``, ``shortcut_type``,
            ``tracking``, ``pre_train_model``, ``mode``, ``groups``,
            ``width_mult`` and ``use_cuda``.

    Returns:
        The ``nn.DataParallel``-wrapped model.
    """
    assert args.model_type in [
        'resnet', 'shufflenet', 'shufflenetv2', 'mobilenet', 'mobilenetv2'
    ]
    # ``== False`` is kept on purpose: it preserves the exact truth table for
    # non-bool values (e.g. None still selects the pre-trained path).
    from_scratch = args.pre_train_model == False or args.mode == 'test'  # noqa: E712
    if from_scratch:
        print('Without Pre-trained model')

    if args.model_type == 'resnet':
        # Checked on both paths; previously an unsupported depth on the
        # pre-trained path surfaced as an UnboundLocalError on ``model``.
        assert args.model_depth in [18, 50, 101]
        if from_scratch:
            shortcut_type = args.shortcut_type
        else:
            # Fixed per depth on the pre-trained path, presumably to match
            # the checkpoints' shortcut type -- confirm.
            shortcut_type = 'A' if args.model_depth == 18 else 'B'
        build = getattr(resnet, 'resnet%d' % args.model_depth)
        model = build(output_dim=args.feature_dim,
                      sample_size=args.sample_size,
                      sample_duration=args.sample_duration,
                      shortcut_type=shortcut_type,
                      tracking=args.tracking,
                      pre_train=args.pre_train_model)
        # Default pre-trained models are trained on Kinetics (600 classes).
        pre_model_path = ('./premodels/kinetics_resnet_' +
                          str(args.model_depth) + '_RGB_16_best.pth')
    elif args.model_type == 'shufflenet':
        model = shufflenet.get_model(groups=args.groups,
                                     width_mult=args.width_mult,
                                     output_dim=args.feature_dim,
                                     pre_train=args.pre_train_model)
        # NOTE(review): the file name hard-codes G3 regardless of args.groups.
        pre_model_path = ('./premodels/kinetics_shufflenet_' +
                          str(args.width_mult) + 'x_G3_RGB_16_best.pth')
    elif args.model_type == 'shufflenetv2':
        model = shufflenetv2.get_model(output_dim=args.feature_dim,
                                       sample_size=args.sample_size,
                                       width_mult=args.width_mult,
                                       pre_train=args.pre_train_model)
        pre_model_path = ('./premodels/kinetics_shufflenetv2_' +
                          str(args.width_mult) + 'x_RGB_16_best.pth')
    elif args.model_type == 'mobilenet':
        model = mobilenet.get_model(sample_size=args.sample_size,
                                    width_mult=args.width_mult,
                                    pre_train=args.pre_train_model)
        pre_model_path = ('./premodels/kinetics_mobilenet_' +
                          str(args.width_mult) + 'x_RGB_16_best.pth')
    elif args.model_type == 'mobilenetv2':
        model = mobilenetv2.get_model(sample_size=args.sample_size,
                                      width_mult=args.width_mult,
                                      pre_train=args.pre_train_model)
        pre_model_path = ('./premodels/kinetics_mobilenetv2_' +
                          str(args.width_mult) + 'x_RGB_16_best.pth')

    # Wrapped on both paths; on the pre-trained path this must happen before
    # loading so the state_dict key names line up with the checkpoint.
    model = nn.DataParallel(model, device_ids=None)

    if not from_scratch:
        model_dict = model.state_dict()
        # The model is still on the CPU here (use_cuda is applied below), so
        # load onto the CPU as well; this also works on CPU-only hosts.
        pretrained_dict = torch.load(
            pre_model_path, map_location=torch.device('cpu'))['state_dict']
        # Copy only entries whose names exist in this model.
        model_dict.update({
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        })
        model.load_state_dict(model_dict)
        model = _construct_depth_model(model)

    if args.use_cuda:
        model = model.cuda()
    return model
Пример #4
0
def generate_model(opt):
    """Build the selected lightweight 3D CNN and optionally load weights.

    Unless ``opt.no_cuda`` is set, the model is moved to the GPU and wrapped
    in ``nn.DataParallel``.  With ``opt.pretrain_path`` set:

    * CUDA path: checkpoint entries whose names exist in the model are
      copied in (a partial load), the classification head is NOT replaced,
      and ``_modify_first_conv_layer`` is applied.
    * CPU path: ``opt.arch`` must match the checkpoint, the state dict is
      loaded strictly, and the head is replaced by one with
      ``opt.n_finetune_classes`` outputs.

    Args:
        opt: options namespace.  Attributes read (depending on the model):
            ``model``, ``n_classes``, ``sample_size``, ``sample_duration``,
            ``version``, ``groups``, ``width_mult``, ``no_cuda``,
            ``pretrain_path``, ``arch`` (CPU path only),
            ``n_finetune_classes`` and ``ft_portion``.

    Returns:
        ``(model, parameters)``: ``parameters`` is the result of the
        architecture's ``get_fine_tuning_parameters`` when a checkpoint was
        loaded, otherwise ``model.parameters()``.
    """
    assert opt.model in [
        'c3d', 'squeezenet', 'mobilenet', 'shufflenet', 'mobilenetv2',
        'shufflenetv2'
    ]

    if opt.model == 'c3d':
        from models.c3d import get_fine_tuning_parameters
        model = c3d.get_model(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'squeezenet':
        from models.squeezenet import get_fine_tuning_parameters
        model = squeezenet.get_model(version=opt.version,
                                     num_classes=opt.n_classes,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'shufflenet':
        from models.shufflenet import get_fine_tuning_parameters
        model = shufflenet.get_model(groups=opt.groups,
                                     width_mult=opt.width_mult,
                                     num_classes=opt.n_classes)
    elif opt.model == 'shufflenetv2':
        from models.shufflenetv2 import get_fine_tuning_parameters
        model = shufflenetv2.get_model(num_classes=opt.n_classes,
                                       sample_size=opt.sample_size,
                                       width_mult=opt.width_mult)
    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
        pytorch_total_params = sum(p.numel() for p in model.parameters()
                                   if p.requires_grad)
        print("Total number of trainable parameters: ", pytorch_total_params)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    # Deserialize onto the CPU on both paths so GPU-saved checkpoints also
    # load on CPU-only hosts; load_state_dict copies into the model's device.
    pretrain = torch.load(opt.pretrain_path,
                          map_location=torch.device('cpu'))

    if use_cuda:
        # Partial load: only checkpoint entries whose names exist in this
        # model are copied; the head is not replaced on this path.
        model_dict = model.state_dict()
        model_dict.update({
            k: v
            for k, v in pretrain['state_dict'].items() if k in model_dict
        })
        model.load_state_dict(model_dict)
        model = _modify_first_conv_layer(model)
        model = model.cuda()
        parameters = get_fine_tuning_parameters(model, opt.ft_portion)
        return model, parameters

    # CPU path: the model is NOT wrapped in nn.DataParallel, so it is
    # addressed directly and DataParallel's 'module.' key prefix is stripped.
    assert opt.arch == pretrain['arch']
    prefix = 'module.'
    state_dict = pretrain['state_dict']
    if all(k.startswith(prefix) for k in state_dict):
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)

    if opt.model in [
            'mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2'
    ]:
        model.classifier = nn.Sequential(
            nn.Dropout(0.9),
            nn.Linear(model.classifier[1].in_features,
                      opt.n_finetune_classes))
    elif opt.model == 'squeezenet':
        model.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv3d(model.classifier[1].in_channels,
                      opt.n_finetune_classes,
                      kernel_size=1), nn.ReLU(inplace=True),
            nn.AvgPool3d((1, 4, 4), stride=1))
    else:
        model.fc = nn.Linear(model.fc.in_features, opt.n_finetune_classes)

    # Same call as the CUDA path, so the same ``ft_portion`` argument.
    parameters = get_fine_tuning_parameters(model, opt.ft_portion)
    return model, parameters