Exemple #1
0
def generate_model(opt):
    """Build the network named by ``opt.model`` and pick the parameters to train.

    The backbone is created by the matching model module (``c3d``,
    ``squeezenet``, ``shufflenet``, ``shufflenetv2``, ``mobilenet``,
    ``mobilenetv2``, ``resnext``, ``resnetl`` or ``resnet``). Unless
    ``opt.no_cuda`` is set, the model is moved to the GPU and wrapped in
    ``nn.DataParallel``. When ``opt.pretrain_path`` is given, the checkpoint's
    ``'state_dict'`` is loaded and the final layer is replaced by a new one
    with ``opt.n_finetune_classes`` outputs.

    Args:
        opt: Options object; which attributes are read depends on
            ``opt.model`` and on whether a pretrained checkpoint is used.

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` is
        ``model.parameters()`` when no checkpoint is loaded, otherwise the
        result of the model module's ``get_fine_tuning_parameters``.

    Raises:
        AssertionError: If ``opt.model`` or ``opt.model_depth`` is not
            supported, or (CPU path only) if ``opt.arch`` differs from the
            checkpoint's ``'arch'`` entry.
    """
    assert opt.model in [
        'c3d', 'squeezenet', 'mobilenet', 'resnext', 'resnet', 'resnetl',
        'shufflenet', 'mobilenetv2', 'shufflenetv2'
    ]

    if opt.model == 'c3d':
        from models.c3d import get_fine_tuning_parameters
        model = c3d.get_model(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'squeezenet':
        from models.squeezenet import get_fine_tuning_parameters
        model = squeezenet.get_model(version=opt.version,
                                     num_classes=opt.n_classes,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'shufflenet':
        from models.shufflenet import get_fine_tuning_parameters
        model = shufflenet.get_model(groups=opt.groups,
                                     width_mult=opt.width_mult,
                                     num_classes=opt.n_classes)
    elif opt.model == 'shufflenetv2':
        from models.shufflenetv2 import get_fine_tuning_parameters
        model = shufflenetv2.get_model(num_classes=opt.n_classes,
                                       sample_size=opt.sample_size,
                                       width_mult=opt.width_mult)
    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        # Constructors are named after the depth: resnet50 / 101 / 152.
        build = getattr(resnext, 'resnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'resnetl':
        assert opt.model_depth in [10]
        from models.resnetl import get_fine_tuning_parameters
        model = resnetl.resnetl10(num_classes=opt.n_classes,
                                  shortcut_type=opt.resnet_shortcut,
                                  sample_size=opt.sample_size,
                                  sample_duration=opt.sample_duration)
    else:  # 'resnet' -- the only name left after the assert above
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        # Constructors are named after the depth: resnet10 ... resnet200.
        build = getattr(resnet, 'resnet{}'.format(int(opt.model_depth)))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)

    on_gpu = not opt.no_cuda
    if on_gpu:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
        pytorch_total_params = sum(p.numel() for p in model.parameters()
                                   if p.requires_grad)
        print("Total number of trainable parameters: ", pytorch_total_params)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    # Always deserialize onto the CPU: load_state_dict copies the values into
    # the model's own device, and this keeps GPU-saved checkpoints loadable
    # on machines without CUDA.
    pretrain = torch.load(opt.pretrain_path,
                          map_location=torch.device('cpu'))
    if not on_gpu:
        # The architecture check was only ever enforced on the CPU path.
        assert opt.arch == pretrain['arch']
    # NOTE(review): a checkpoint saved from a DataParallel-wrapped model
    # presumably has 'module.'-prefixed keys and would not load into the
    # unwrapped CPU model -- confirm which kind of checkpoint is expected.
    model.load_state_dict(pretrain['state_dict'])

    # Only the GPU path wraps the network in DataParallel; on the CPU the
    # model itself owns the head, so `model.module` must not be dereferenced.
    net = model.module if on_gpu else model

    if opt.model in ['mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2']:
        # NOTE(review): the CPU path has always used dropout 0.9 while the
        # GPU path uses 0.5; preserved as-is, confirm which one is intended.
        dropout_rate = 0.5 if on_gpu else 0.9
        head_name = 'classifier'
        new_head = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(net.classifier[1].in_features, opt.n_finetune_classes))
    elif opt.model == 'squeezenet':
        head_name = 'classifier'
        new_head = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv3d(net.classifier[1].in_channels,
                      opt.n_finetune_classes,
                      kernel_size=1), nn.ReLU(inplace=True),
            nn.AvgPool3d((1, 4, 4), stride=1))
    else:
        head_name = 'fc'
        new_head = nn.Linear(net.fc.in_features, opt.n_finetune_classes)

    if on_gpu:
        # The freshly created layer starts on the CPU.
        new_head = new_head.cuda()
    setattr(net, head_name, new_head)

    # NOTE(review): the GPU path passes opt.ft_portion, the CPU path
    # opt.ft_begin_index; both are kept as found -- verify against the
    # signature of get_fine_tuning_parameters.
    ft_argument = opt.ft_portion if on_gpu else opt.ft_begin_index
    parameters = get_fine_tuning_parameters(model, ft_argument)
    return model, parameters
Exemple #2
0
def generate_model(opt):
    """Build a MobileNet-family network and pick the parameters to train.

    Builders exist for ``mobilenet``, ``mobilenetv2``, ``fast_mobilenetv2``,
    ``slow_mobilenetv2`` and ``slow_fast_mobilenetv2``. Unless ``opt.no_cuda``
    is set, the model is moved to the GPU and wrapped in ``nn.DataParallel``.
    When ``opt.pretrain_path`` is given, the checkpoint's ``'state_dict'`` is
    loaded and the final layer is replaced by one with
    ``opt.n_finetune_classes`` outputs; otherwise, when ``opt.resume_path`` is
    given, that checkpoint's ``'state_dict'`` is loaded instead.

    Args:
        opt: Options object; which attributes are read depends on
            ``opt.model`` and on the checkpoint options.

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` is ``[]`` for GPU
        inference with a pretrained checkpoint, the result of
        ``get_fine_tuning_parameters`` when fine-tuning, and
        ``model.parameters()`` otherwise.

    Raises:
        AssertionError: If ``opt.model`` is not in the accepted list, or CUDA
            is requested but not available.
        NotImplementedError: If ``opt.model`` is accepted by the assert but no
            builder for it exists in this function.
    """
    assert opt.model in [
        'c3d', 'squeezenet', 'mobilenet', 'resnext', 'resnet', 'shufflenet',
        'mobilenetv2', 'shufflenetv2', 'slow_mobilenetv2', 'fast_mobilenetv2',
        'slow_fast_mobilenetv2'
    ]

    if opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)
    elif opt.model == 'fast_mobilenetv2':
        from models.fast_mobilenetv2 import get_fine_tuning_parameters
        model = fast_mobilenetv2.get_model(num_classes=opt.n_classes,
                                           sample_size=opt.sample_size,
                                           width_mult=opt.width_mult)
    elif opt.model == 'slow_mobilenetv2':
        from models.slow_mobilenetv2 import get_fine_tuning_parameters
        model = slow_mobilenetv2.get_model(num_classes=opt.n_classes,
                                           sample_size=opt.sample_size,
                                           width_mult=opt.width_mult)
    elif opt.model == 'slow_fast_mobilenetv2':
        from models.slow_fast_mobilenetv2 import get_fine_tuning_parameters
        model = slow_fast_mobilenetv2.get_model(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            width_mult_slow=opt.width_mult_slow,
            beta=opt.beta,
            fusion_kernel_size=opt.fusion_kernel_size,
            fusion_conv_channel_ratio=opt.fusion_conv_channel_ratio,
            slow_frames=opt.slow_frames,
            fast_frames=opt.fast_frames,
            lateral_connection_section_indices=opt.
            lateral_connection_section_indices)
    else:
        # These names pass the assert above but were never given a builder;
        # fail with a clear message instead of an unbound `model` variable.
        raise NotImplementedError(
            'no builder available for model {!r}'.format(opt.model))

    on_gpu = not opt.no_cuda
    if on_gpu:
        # Check for CUDA before the first call that needs it.
        assert torch.cuda.is_available()
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
        pytorch_total_params = sum(p.numel() for p in model.parameters()
                                   if p.requires_grad)
        print("Using GPU {}".format(torch.cuda.get_device_name(0)))
        print("Total number of trainable parameters: ", pytorch_total_params)

    if opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        # Always deserialize onto the CPU: load_state_dict copies the values
        # into the model's own device, and this keeps GPU-saved checkpoints
        # loadable on machines without CUDA.
        pretrain = torch.load(opt.pretrain_path,
                              map_location=torch.device('cpu'))
        # NOTE(review): a checkpoint saved from a DataParallel-wrapped model
        # presumably has 'module.'-prefixed keys and would not load into the
        # unwrapped CPU model -- confirm which kind of checkpoint is expected.
        model.load_state_dict(pretrain['state_dict'])

        # NOTE(review): the inference shortcut has only ever existed on the
        # GPU path; kept that way -- confirm whether CPU should honour it too.
        if on_gpu and opt.inference:
            return model, []

        # Only the GPU path wraps the network in DataParallel; on the CPU the
        # model itself owns the head, so `model.module` must not be used.
        net = model.module if on_gpu else model

        if opt.model in [
                'mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2',
                'slow_fast_mobilenetv2'
        ]:
            head_name = 'classifier'
            new_head = nn.Sequential(
                nn.Dropout(0.9),
                nn.Linear(net.classifier[1].in_features,
                          opt.n_finetune_classes))
        elif opt.model == 'squeezenet':
            # Not reachable until a squeezenet builder is added above.
            head_name = 'classifier'
            new_head = nn.Sequential(
                nn.Dropout(p=0.5),
                nn.Conv3d(net.classifier[1].in_channels,
                          opt.n_finetune_classes,
                          kernel_size=1), nn.ReLU(inplace=True),
                nn.AvgPool3d((1, 4, 4), stride=1))
        else:
            head_name = 'fc'
            new_head = nn.Linear(net.fc.in_features, opt.n_finetune_classes)

        if on_gpu:
            # The freshly created layer starts on the CPU.
            new_head = new_head.cuda()
        setattr(net, head_name, new_head)

        # NOTE(review): the GPU path passes opt.ft_portion, the CPU path
        # opt.ft_begin_index; both are kept as found -- verify against the
        # signature of get_fine_tuning_parameters.
        ft_argument = opt.ft_portion if on_gpu else opt.ft_begin_index
        parameters = get_fine_tuning_parameters(model, ft_argument)
        return model, parameters

    if opt.resume_path:
        if not opt.no_cuda:
            model = model.cuda()
        model_state = torch.load(opt.resume_path,
                                 map_location=torch.device('cpu'))
        model.load_state_dict(model_state['state_dict'])

    return model, model.parameters()
def generate_model(args):
    """Build a ``nn.DataParallel``-wrapped backbone, optionally pretrained.

    The network is built from scratch when ``args.pre_train_model`` equals
    ``False`` or ``args.mode`` is ``'test'``. Otherwise the matching checkpoint
    under ``./premodels/`` is read, every entry of its ``'state_dict'`` whose
    key also exists in the model is copied in, and the model is passed through
    ``_construct_depth_model``.

    Args:
        args: Options object; which attributes are read depends on
            ``args.model_type``.

    Returns:
        The wrapped model, moved to the GPU when ``args.use_cuda`` is truthy.

    Raises:
        AssertionError: If ``args.model_type`` is unsupported, or if
            ``args.model_depth`` is unsupported for ``'resnet'``.
    """
    assert args.model_type in [
        'resnet', 'shufflenet', 'shufflenetv2', 'mobilenet', 'mobilenetv2'
    ]

    # `== False` is kept on purpose: a value of None must still select the
    # pretrained path, exactly as before.
    from_scratch = args.pre_train_model == False or args.mode == 'test'
    if from_scratch:
        print('Without Pre-trained model')

    if args.model_type == 'resnet':
        # Validate the depth on both paths; without this an unsupported depth
        # on the pretrained path left `model` unassigned.
        assert args.model_depth in [18, 50, 101]
        if from_scratch:
            shortcut_type = args.shortcut_type
        else:
            # The pretrained path fixes the shortcut type per depth:
            # 'A' for depth 18, 'B' for depths 50 and 101.
            shortcut_type = 'A' if args.model_depth == 18 else 'B'
        # Constructors are named after the depth: resnet18 / 50 / 101.
        build = getattr(resnet, 'resnet{}'.format(int(args.model_depth)))
        model = build(output_dim=args.feature_dim,
                      sample_size=args.sample_size,
                      sample_duration=args.sample_duration,
                      shortcut_type=shortcut_type,
                      tracking=args.tracking,
                      pre_train=args.pre_train_model)
    elif args.model_type == 'shufflenet':
        model = shufflenet.get_model(groups=args.groups,
                                     width_mult=args.width_mult,
                                     output_dim=args.feature_dim,
                                     pre_train=args.pre_train_model)
    elif args.model_type == 'shufflenetv2':
        model = shufflenetv2.get_model(output_dim=args.feature_dim,
                                       sample_size=args.sample_size,
                                       width_mult=args.width_mult,
                                       pre_train=args.pre_train_model)
    elif args.model_type == 'mobilenet':
        model = mobilenet.get_model(sample_size=args.sample_size,
                                    width_mult=args.width_mult,
                                    pre_train=args.pre_train_model)
    else:  # 'mobilenetv2' -- the only name left after the assert above
        model = mobilenetv2.get_model(sample_size=args.sample_size,
                                      width_mult=args.width_mult,
                                      pre_train=args.pre_train_model)

    # Wrapped before any weights are loaded, in order to load the
    # pre-trained model.
    model = nn.DataParallel(model, device_ids=None)

    if not from_scratch:
        # Default pre-trained model is trained on the Kinetics dataset, which
        # has 600 classes.
        if args.model_type == 'resnet':
            pre_model_path = './premodels/kinetics_resnet_' + str(
                args.model_depth) + '_RGB_16_best.pth'
        elif args.model_type == 'shufflenet':
            pre_model_path = './premodels/kinetics_shufflenet_' + str(
                args.width_mult) + 'x_G3_RGB_16_best.pth'
        else:
            # shufflenetv2 / mobilenet / mobilenetv2 share one naming scheme.
            pre_model_path = ('./premodels/kinetics_' + args.model_type +
                              '_' + str(args.width_mult) +
                              'x_RGB_16_best.pth')

        model_dict = model.state_dict()
        # Deserialize onto the CPU so GPU-saved checkpoints also load on
        # machines without CUDA; load_state_dict copies the values over.
        pretrained_dict = torch.load(
            pre_model_path, map_location=torch.device('cpu'))['state_dict']
        # Keep only the entries this model actually has.
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        model = _construct_depth_model(model)

    if args.use_cuda:
        model = model.cuda()
    return model
def main():
    """Evaluate a 3D-MobileNetV2 (optionally with LSTM) model on the test set.

    Reads the run settings from ``config/config.json``, builds the model named
    by its ``model`` entry, loads the weights file if one is configured and
    exists, runs the model over the test directory and prints the number of
    test samples, the number of correct predictions and the accuracy.

    Raises:
        ValueError: If the configured model name is not ``'3D-MobileNetV2'``
            or ``'3D-MobileNetV2-LSTM'``.
    """
    # TODO: the visible GPU index is hard-coded; make it configurable.
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"  ## todo

    with open('config/config.json', 'r') as f:
        cfg = json.load(f)
    batch_size = int(cfg['batch'])
    # num_epochs, learning_rate, train_dir and eval_dir are only used by the
    # disabled training code below; they are still read so that a missing
    # config key is reported the same way as before.
    num_epochs = int(cfg['epochs'])
    num_classes = int(cfg['class_number'])
    shape = (int(cfg['height']), int(cfg['width']), 3)
    learning_rate = cfg['learning_rate']
    pre_weights = cfg['weights']
    train_dir = cfg['train_dir']
    eval_dir = cfg['eval_dir']
    test_dir = cfg['test_dir']
    model_name = cfg['model']

    # Select the model (3D-MobileNetV2, 3D-MobileNetV2+LSTM)
    if model_name == '3D-MobileNetV2':
        model = mobilenetv2.get_model(num_classes=num_classes,
                                      sample_size=shape[0],
                                      width_mult=1.0)
    elif model_name == '3D-MobileNetV2-LSTM':
        model = mobilenetv2_lstm.get_model(num_classes=num_classes,
                                           sample_size=shape[0],
                                           width_mult=1.0)
    else:
        # Previously an unknown name fell through and crashed on str.cuda().
        raise ValueError('unsupported model in config: {!r}'.format(model_name))

    # Use one device for the model, the weights and the batches, so the
    # script also runs when CUDA is not available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    #load weights if it has
    if pre_weights and os.path.exists(pre_weights):
        weights = torch.load(pre_weights, map_location=device)
        model.load_state_dict(weights)

    # The two string blocks below are disabled code (partial weight loading
    # for fine-tuning, and the training loop) kept for reference only.
    '''
    model_path = r'./logs/fit/1/jester_mobilenetv2_1.0x_RGB_16_best.pth'
    checkpoint = torch.load(model_path)
    #load part of the weight because of finetune
    pretrained_dict = checkpoint['state_dict']
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(pretrained_dict)
    '''
    '''
    train_dataloader = DataLoader(SequentialDataset(root_path=train_dir,images_len=10,rescale=1/255.), batch_size=batch_size,shuffle=True, num_workers=4)
    val_dataloader = DataLoader(SequentialDataset(root_path=eval_dir,images_len=10,rescale=1/255.), batch_size=batch_size, num_workers=4)
    trainval_loaders = {'train': train_dataloader, 'val': val_dataloader}
    trainval_sizes = {x: len(trainval_loaders[x].dataset) for x in ['train', 'val']}
    test_size = len(test_dataloader.dataset)
    
    image_dir = {'train': train_dir,
           'val': eval_dir}
    dataloaders_dict = {
        x: DataLoader(SequentialDataset(root_path=image_dir[x],images_len=40,rescale=1/255.), batch_size=batch_size, shuffle=True, num_workers=4) for x in
        ['train', 'val']}
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cudnn.benchmark = True

    # Send the model to GPU
    model = model.to(device)

    params_to_update = model.parameters()

    #print("Params to learn:")
    #for name, param in model.named_parameters():
    #    if param.requires_grad == True:
    #        print("\t", name)

    # Observe that all parameters are being optimized
    #optimizer = optim.SGD(params_to_update, lr=learning_rate, momentum=0.9,weight_decay=5e-4)
    optimizer = optim.Adam(params_to_update,lr=learning_rate)

    # Setup the loss fxn
    criterion = torch.nn.CrossEntropyLoss()

    # Train and evaluate
    log_dir = "logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
    model, hist = train_model(log_dir, model, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)
    path = log_dir + '/model_final.pkl'
    torch.save(model.state_dict(), path)
    '''
    #test the result
    running_corrects = 0.0
    test_dataloader = DataLoader(SequentialDataset(root_path=test_dir,
                                                   images_len=40,
                                                   rescale=1 / 255.),
                                 batch_size=batch_size,
                                 num_workers=4)

    test_size = len(test_dataloader.dataset)
    model.eval()
    for inputs, labels in test_dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            outputs = model(inputs)
        probs = torch.nn.Softmax(dim=1)(outputs)
        # Index of the highest-probability class for each sample.
        preds = torch.max(probs, 1)[1]
        running_corrects += torch.sum(preds == labels.data)

    print(test_size)
    print(running_corrects)
    epoch_acc = running_corrects / test_size
    print(epoch_acc)
Exemple #5
0
def generate_model(opt):
    """Build the 3D CNN selected by ``opt.model`` and return it with its parameters.

    Attributes read from ``opt``: ``model``, ``n_classes``, ``sample_size`` and,
    depending on the family, ``model_depth``, ``sample_duration``,
    ``resnet_shortcut``, ``wide_resnet_k``, ``resnext_cardinality`` and
    ``width_mult``. Device placement is driven by ``no_cuda`` and
    ``no_cuda_predict``. When ``pretrain_path`` is set, ``arch``,
    ``ft_begin_index`` and ``n_finetune_classes`` are read as well.

    Returns:
        A ``(model, parameters)`` tuple. Without ``opt.pretrain_path`` the
        second item is ``model.parameters()``; otherwise it is the result of
        the selected family's ``get_fine_tuning_parameters``.

    Raises:
        AssertionError: if ``opt.model`` or ``opt.model_depth`` is not
            supported, or the checkpoint's ``'arch'`` differs from ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet',
        'mobilenet', 'mobilenetv2'
    ]

    # For the depth-parameterised families the depth is validated first, so the
    # getattr lookups below resolve to the same factories the explicit
    # per-depth branches used to call (e.g. resnet.resnet18).
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        factory = getattr(resnet, 'resnet{}'.format(int(opt.model_depth)))
        model = factory(num_classes=opt.n_classes,
                        shortcut_type=opt.resnet_shortcut,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     k=opt.wide_resnet_k,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        factory = getattr(resnext, 'resnet{}'.format(int(opt.model_depth)))
        model = factory(num_classes=opt.n_classes,
                        shortcut_type=opt.resnet_shortcut,
                        cardinality=opt.resnext_cardinality,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        factory = getattr(pre_act_resnet,
                          'resnet{}'.format(int(opt.model_depth)))
        model = factory(num_classes=opt.n_classes,
                        shortcut_type=opt.resnet_shortcut,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        factory = getattr(densenet,
                          'densenet{}'.format(int(opt.model_depth)))
        model = factory(num_classes=opt.n_classes,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)

    on_gpu = not opt.no_cuda
    if on_gpu:
        if not opt.no_cuda_predict:
            model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    if on_gpu:
        pretrain = torch.load(opt.pretrain_path)
        print("Pretrain arch", pretrain['arch'])
    else:
        # Remap storages to the CPU so a checkpoint written from GPU tensors
        # can still be read on a host without CUDA.
        pretrain = torch.load(opt.pretrain_path, map_location='cpu')
    assert opt.arch == pretrain['arch']

    model.load_state_dict(pretrain['state_dict'])

    # Only the GPU path wraps the network in DataParallel, so only there do
    # the layers live under ``.module``. The CPU path previously went through
    # ``model.module`` (and ``.cuda()``) for the mobilenet family, which fails
    # on an unwrapped model.
    net = model.module if on_gpu else model

    ft_begin_index = opt.ft_begin_index
    if opt.model in ['mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2']:
        head = nn.Sequential(
            nn.Dropout(0.9),
            nn.Linear(net.classifier[1].in_features, opt.n_finetune_classes))
        net.classifier = head.cuda() if on_gpu else head
        # For these families the numeric index is translated into the two
        # string modes before being handed to get_fine_tuning_parameters.
        ft_begin_index = 'complete' if ft_begin_index == 0 else 'last_layer'
    elif opt.model == 'densenet':
        head = nn.Linear(net.classifier.in_features, opt.n_finetune_classes)
        net.classifier = head.cuda() if on_gpu else head
    else:
        head = nn.Linear(net.fc.in_features, opt.n_finetune_classes)
        net.fc = head.cuda() if on_gpu else head

    print("Finetuning at:", ft_begin_index)
    parameters = get_fine_tuning_parameters(model, ft_begin_index)
    return model, parameters
Exemple #6
0
def get_fine_tuned(id):
    """Run one fine-tuning session for ``id`` and return the trained model.

    Re-parses the command-line options into the module-level ``opt``,
    overwrites them with the fixed configuration below, obtains the training
    and validation data from ``getData``, builds a MobileNetV2 network, passes
    it through ``intialize`` and finally hands it to ``train``.
    """
    global opt
    opt = parse_opts()

    # Fixed option overrides (applied in this order).
    for key, value in {
            'confidence_threshold': 0.1,
            'frame_to_send': 60,
            'downsample': 2,
            'width_mult': 1.0,
    }.items():
        setattr(opt, key, value)
    opt.pretrain_path = f'./PreTrained/jester_mobilenetv2_{opt.width_mult}x_RGB_16_best.pth'
    for key, value in {
            'n_classes': 27,
            'sample_size': 112,
            'scale_in_test': 1.0,
            'no_cuda': False,
            'model': "mobilenetv2",
            'arch': "mobilenetv2",
            'no_train': False,
            'n_threads': 0,
            'ft_portion': 'last_layer',
            'use_amp': True,
    }.items():
        setattr(opt, key, value)

    # Geometric series of scales starting at initial_scale.
    scales = [opt.initial_scale]
    for _ in range(1, opt.n_scales):
        scales.append(scales[-1] * opt.scale_step)
    opt.scales = scales

    opt.train_directory = './TrainData'
    data_root = opt.train_directory + str(id)
    fetched = getData(id, data_root)
    training_data, validation_data, opt.n_finetune_classes = fetched
    print(f'Received Data for {opt.n_finetune_classes} classes')
    print(
        f'Train Size: {len(training_data)}, Validation Size: {len(validation_data)}'
    )

    # Network construction followed by initialisation.
    network = mobilenetv2.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    print('Initializing model')
    network, parameters = intialize(network)
    print('Model Intialized')

    # Hyper-parameters for this run.
    opt.n_epochs = 5
    opt.batch_size = 8
    opt.learning_rate = 0.01

    # Fresh, independent history lists stored on opt.
    for history_name in ('top1_values', 'top3_values', 'train_loss_values',
                         'top1_values_val', 'top3_values_val',
                         'val_loss_values'):
        setattr(opt, history_name, [])

    begin_time = time.time()
    print("starting training")
    network = train(network, parameters, training_data, validation_data)
    # Drop the datasets before asking CUDA to release its cached memory.
    del training_data, validation_data
    torch.cuda.empty_cache()
    total_time = time.time() - begin_time
    print(f"Training finished in {total_time}")
    # The string below is an inert, disabled reporting snippet kept as-is.
    """
    df_top1 = pd.DataFrame(opt.top1_values)
    df_top1.to_csv(f'./Graph_data/top1_{lr}_{bsize}.csv',header=False,index=False)
    df_top3 = pd.DataFrame(opt.top3_values)
    df_top3.to_csv(f'./Graph_data/top3_{lr}_{bsize}.csv',header=False,index=False)
    df_train_loss = pd.DataFrame(opt.train_loss_values)
    df_train_loss.to_csv(f'./Graph_data/train_loss_{lr}_{bsize}.csv',header=False,index=False)
    
    df_top1_val = pd.DataFrame(opt.top1_values_val)
    df_top1_val.to_csv(f'./Graph_data/top1_val_{lr}_{bsize}.csv',header=False,index=False)
    df_top3_val = pd.DataFrame(opt.top3_values_val)
    df_top3_val.to_csv(f'./Graph_data/top3_val_{lr}_{bsize}.csv',header=False,index=False)
    df_val_loss = pd.DataFrame(opt.val_loss_values)
    df_val_loss.to_csv(f'./Graph_data/val_loss_{lr}_{bsize}.csv',header=False,index=False)
    
    print('lr:',lr,'batch_size:',bsize,'train_acc:',opt.top1_values[-1],'val_acc:',opt.top1_values_val[-1],'time:',total_time)

    report.append({'lr':lr,'batch_size':bsize,'train_acc':opt.top1_values,'val_acc':opt.top1_values_val,'time':total_time})
    
    print('Report:')
    print(report)
    """
    return network
Exemple #7
0
        opt.scales.append(opt.scales[-1] * opt.scale_step)

    opt.mean = get_mean(opt.norm_value, dataset=opt.mean_dataset)
    opt.std = get_std(opt.norm_value)

    # Normalization (NOT USED)
    if opt.no_mean_norm and not opt.std_norm:
        norm_method = Normalize([0, 0, 0], [1, 1, 1])
    elif not opt.std_norm:
        norm_method = Normalize(opt.mean, [1, 1, 1])
    else:
        norm_method = Normalize(opt.mean, opt.std)

    # Model definition
    model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                  sample_size=opt.sample_size,
                                  width_mult=opt.width_mult)

    # Model Initialization
    intialize(model)
    model.eval()

    save_cntr = 0

    # define a video capture object
    vid = cv2.VideoCapture(0)
    list_of_frames = []
    cnt = 0
    while (True):

        # Capture the video frame
Exemple #8
0
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import pdb

from utils.utils import AverageMeter, calculate_accuracy
from models import mobilenet, mobilenetv2

# Model selection: exactly one of the alternative constructions below is left
# uncommented (currently #4, MobileNetV2 with width_mult=0.2). The trailing
# "#N" tags number the alternatives. The commented-out lines use modules
# (shufflenet, shufflenetv2, squeezenet, resnet) that the `from models import`
# line above does not bring in, so enabling one of them also requires
# importing its module.
# model = shufflenet.get_model(groups=3, width_mult=0.5, num_classes=600)#1
# model = shufflenetv2.get_model( width_mult=0.25, num_classes=600, sample_size = 112)#2
# model = mobilenet.get_model( width_mult=0.5, num_classes=600, sample_size = 112)#3
model = mobilenetv2.get_model(width_mult=0.2, num_classes=600,
                              sample_size=112)  #4
# model = shufflenet.get_model(groups=3, width_mult=1.0, num_classes=600)#5
# model = shufflenetv2.get_model( width_mult=1.0, num_classes=600, sample_size = 112)#6
# model = mobilenet.get_model( width_mult=1.0, num_classes=600, sample_size = 112)#7
# model = mobilenetv2.get_model( width_mult=0.45, num_classes=600, sample_size = 112)#8
# model = shufflenet.get_model(groups=3, width_mult=1.5, num_classes=600)#9
# model = shufflenetv2.get_model( width_mult=1.5, num_classes=600, sample_size = 112)#10
# model = mobilenet.get_model( width_mult=1.5, num_classes=600, sample_size = 112)#11
# model = mobilenetv2.get_model( width_mult=0.7, num_classes=600, sample_size = 112)#12
# model = shufflenet.get_model(groups=3, width_mult=2.0, num_classes=600)#13
# model = shufflenetv2.get_model( width_mult=2.0, num_classes=600, sample_size = 112)#14
# model = mobilenet.get_model( width_mult=2.0, num_classes=600, sample_size = 112)#15
# model = mobilenetv2.get_model( width_mult=1.0, num_classes=600, sample_size = 112)#16
# model = squeezenet.get_model( version=1.1, num_classes=600, sample_size = 112, sample_duration = 8)
# model = resnet.resnet18(sample_size = 112, sample_duration = 8, num_classes=600)
# model = resnet.resnet50(sample_size = 112, sample_duration = 8, num_classes=600)
Exemple #9
0
def generate_model(opt):
    """Build the network selected by ``opt.model``, adapt it, and return it.

    After construction the model is optionally moved to the GPU and wrapped in
    ``nn.DataParallel``, optionally initialised from ``opt.pretrain_path`` with
    its classification head replaced for ``opt.n_finetune_classes``, converted
    according to ``opt.modality`` ('RGB', 'Depth' or 'RGB-D'), and has its
    first ``Conv3d`` rebuilt when its temporal kernel size exceeds
    ``opt.sample_duration``.

    Returns:
        A ``(model, parameters)`` tuple, where ``parameters`` comes from the
        selected family's ``get_fine_tuning_parameters``.

    Raises:
        AssertionError: if ``opt.model`` or ``opt.model_depth`` is not
            supported, or (CPU path only) the checkpoint's ``'arch'`` differs
            from ``opt.arch``.
        IndexError: if the model contains no ``nn.Conv3d`` module.
    """
    assert opt.model in [
        'c3d', 'squeezenet', 'mobilenet', 'resnext', 'resnet', 'resnetl',
        'shufflenet', 'mobilenetv2', 'shufflenetv2'
    ]

    if opt.model == 'c3d':
        from models.c3d import get_fine_tuning_parameters
        model = c3d.get_model(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'squeezenet':
        from models.squeezenet import get_fine_tuning_parameters
        model = squeezenet.get_model(version=opt.version,
                                     num_classes=opt.n_classes,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'shufflenet':
        from models.shufflenet import get_fine_tuning_parameters
        model = shufflenet.get_model(groups=opt.groups,
                                     width_mult=opt.width_mult,
                                     num_classes=opt.n_classes)
    elif opt.model == 'shufflenetv2':
        from models.shufflenetv2 import get_fine_tuning_parameters
        model = shufflenetv2.get_model(num_classes=opt.n_classes,
                                       sample_size=opt.sample_size,
                                       width_mult=opt.width_mult)
    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        # Depth was validated above, so this resolves to resnext.resnet50,
        # resnext.resnet101 or resnext.resnet152.
        factory = getattr(resnext, 'resnet{}'.format(int(opt.model_depth)))
        model = factory(num_classes=opt.n_classes,
                        shortcut_type=opt.resnet_shortcut,
                        cardinality=opt.resnext_cardinality,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)
    elif opt.model == 'resnetl':
        assert opt.model_depth in [10]

        from models.resnetl import get_fine_tuning_parameters

        model = resnetl.resnetl10(num_classes=opt.n_classes,
                                  shortcut_type=opt.resnet_shortcut,
                                  sample_size=opt.sample_size,
                                  sample_duration=opt.sample_duration)
    elif opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        # Depth was validated above, so this resolves to one of
        # resnet.resnet10 ... resnet.resnet200.
        factory = getattr(resnet, 'resnet{}'.format(int(opt.model_depth)))
        model = factory(num_classes=opt.n_classes,
                        shortcut_type=opt.resnet_shortcut,
                        sample_size=opt.sample_size,
                        sample_duration=opt.sample_duration)

    # Families whose classification head is ``classifier`` = [Dropout, Linear].
    mobile_family = ['mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2']

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
        pytorch_total_params = sum(p.numel() for p in model.parameters()
                                   if p.requires_grad)
        print("Total number of trainable parameters: ", pytorch_total_params)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path,
                                  map_location=torch.device('cpu'))
            # The arch consistency check is disabled on this path; it is
            # still enforced on the CPU path below.
            model.load_state_dict(pretrain['state_dict'])

            # The network sits under ``.module`` because of DataParallel.
            if opt.model in mobile_family:
                model.module.classifier = nn.Sequential(
                    nn.Dropout(0.5),
                    nn.Linear(model.module.classifier[1].in_features,
                              opt.n_finetune_classes))
                model.module.classifier = model.module.classifier.cuda()
            elif opt.model == 'squeezenet':
                model.module.classifier = nn.Sequential(
                    nn.Dropout(p=0.5),
                    nn.Conv3d(model.module.classifier[1].in_channels,
                              opt.n_finetune_classes,
                              kernel_size=1), nn.ReLU(inplace=True),
                    nn.AvgPool3d((1, 4, 4), stride=1))
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

        if opt.modality == 'RGB' and opt.model != 'c3d':
            print("[INFO]: RGB model is used for init model")
            model = _modify_first_conv_layer(
                model, 3, 3)  ##### Check models trained (3,7,7) or (7,7,7)
        elif opt.modality == 'Depth':
            print(
                "[INFO]: Converting the pretrained model to Depth init model")
            model = _construct_depth_model(model)
            # Message previously said "Flow model" although this is the
            # Depth conversion.
            print("[INFO]: Done. Depth model ready.")
        elif opt.modality == 'RGB-D':
            print(
                "[INFO]: Converting the pretrained model to RGB+D init model")
            model = _construct_rgbdepth_model(model)
            print("[INFO]: Done. RGB-D model ready.")

        # First Conv3d in module traversal order (IndexError if there is none).
        conv_layer = [
            m for m in model.modules() if isinstance(m, nn.Conv3d)
        ][0]
        if conv_layer.kernel_size[0] > opt.sample_duration:
            model = _modify_first_conv_layer(model,
                                             int(opt.sample_duration / 2), 1)

        parameters = get_fine_tuning_parameters(model, opt.ft_portion)
        return model, parameters

    # ---- CPU path: the model is NOT wrapped in DataParallel, so its layers
    # are attributes of ``model`` itself (there is no ``model.module``).
    if opt.pretrain_path:
        print('loading pretrained model {}'.format(opt.pretrain_path))
        # Remap storages to the CPU so a checkpoint written from GPU tensors
        # can be read on a host without CUDA (same as the GPU path).
        pretrain = torch.load(opt.pretrain_path,
                              map_location=torch.device('cpu'))
        assert opt.arch == pretrain['arch']
        model.load_state_dict(pretrain['state_dict'])

        if opt.model in mobile_family:
            # NOTE(review): Dropout is 0.9 here but 0.5 on the GPU path;
            # both values are kept as found - confirm which is intended.
            model.classifier = nn.Sequential(
                nn.Dropout(0.9),
                nn.Linear(model.classifier[1].in_features,
                          opt.n_finetune_classes))
        elif opt.model == 'squeezenet':
            model.classifier = nn.Sequential(
                nn.Dropout(p=0.5),
                nn.Conv3d(model.classifier[1].in_channels,
                          opt.n_finetune_classes,
                          kernel_size=1), nn.ReLU(inplace=True),
                nn.AvgPool3d((1, 4, 4), stride=1))
        else:
            model.fc = nn.Linear(model.fc.in_features,
                                 opt.n_finetune_classes)

    if opt.modality == 'RGB' and opt.model != 'c3d':
        print("[INFO]: RGB model is used for init model")
        model = _modify_first_conv_layer(model, 3, 3)
    elif opt.modality == 'Depth':
        print("[INFO]: Converting the pretrained model to Depth init model")
        model = _construct_depth_model(model)
        print("[INFO]: Depth model ready.")
    elif opt.modality == 'RGB-D':
        print("[INFO]: Converting the pretrained model to RGB-D init model")
        model = _construct_rgbdepth_model(model)
        print("[INFO]: Done. RGB-D model ready.")

    # First Conv3d in module traversal order (IndexError if there is none).
    conv_layer = [m for m in model.modules() if isinstance(m, nn.Conv3d)][0]
    if conv_layer.kernel_size[0] > opt.sample_duration:
        print("[INFO]: RGB model is used for init model")
        model = _modify_first_conv_layer(model,
                                         int(opt.sample_duration / 2), 1)

    if opt.model == 'c3d':  # CHECK HERE
        model.fc = nn.Linear(model.fc[0].in_features,
                             model.fc[0].out_features)
    elif hasattr(model, 'fc'):
        # Only families that expose ``fc`` get a fresh final layer; the
        # classifier-based families have no ``fc`` and used to crash here.
        model.fc = nn.Linear(model.fc.in_features, opt.n_finetune_classes)

    # NOTE(review): the GPU path passes opt.ft_portion here while this path
    # passes opt.ft_begin_index - kept as found, confirm against the
    # get_fine_tuning_parameters implementations.
    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
Exemple #10
0
def generate_model(opt):
    """Build the 3D-CNN selected by ``opt.model`` and optionally load weights.

    Args:
        opt: Options namespace. Attributes read here: ``model``,
            ``n_classes``, ``sample_size``, ``sample_duration``, ``version``,
            ``groups``, ``width_mult`` (depending on the architecture),
            ``no_cuda``, ``pretrain_path``, and, when a pretrained checkpoint
            is loaded, ``ft_portion`` (CUDA path) or ``arch``,
            ``n_finetune_classes`` and ``ft_begin_index`` (CPU path).

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` is the result of the
        architecture's ``get_fine_tuning_parameters`` when a pretrained
        checkpoint was loaded, otherwise ``model.parameters()``.

    Raises:
        AssertionError: If ``opt.model`` is not a supported architecture, or
            (CPU path) if ``opt.arch`` differs from the checkpoint's ``arch``.
    """
    assert opt.model in [
        'c3d', 'squeezenet', 'mobilenet', 'shufflenet', 'mobilenetv2',
        'shufflenetv2'
    ]

    # Each branch imports the matching ``get_fine_tuning_parameters`` so the
    # fine-tuning code below uses the helper of the selected architecture.
    if opt.model == 'c3d':
        from models.c3d import get_fine_tuning_parameters
        model = c3d.get_model(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'squeezenet':
        from models.squeezenet import get_fine_tuning_parameters
        model = squeezenet.get_model(version=opt.version,
                                     num_classes=opt.n_classes,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'shufflenet':
        from models.shufflenet import get_fine_tuning_parameters
        model = shufflenet.get_model(groups=opt.groups,
                                     width_mult=opt.width_mult,
                                     num_classes=opt.n_classes)
    elif opt.model == 'shufflenetv2':
        from models.shufflenetv2 import get_fine_tuning_parameters
        model = shufflenetv2.get_model(num_classes=opt.n_classes,
                                       sample_size=opt.sample_size,
                                       width_mult=opt.width_mult)
    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
        pytorch_total_params = sum(p.numel() for p in model.parameters()
                                   if p.requires_grad)
        print("Total number of trainable parameters: ", pytorch_total_params)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # Deserialize onto the CPU first; load_state_dict copies the
            # tensors into the (already CUDA-resident) model parameters.
            pretrain = torch.load(opt.pretrain_path,
                                  map_location=torch.device('cpu'))
            # Partial load: only checkpoint entries whose key exists in the
            # current model are taken; everything else keeps its init value.
            model_dict = model.state_dict()
            pretrain_dict = {
                k: v
                for k, v in pretrain['state_dict'].items() if k in model_dict
            }
            model_dict.update(pretrain_dict)
            model.load_state_dict(model_dict)

            # Unlike the CPU path, the classifier head is not replaced here;
            # the model is only passed through _modify_first_conv_layer.
            model = _modify_first_conv_layer(model)
            model = model.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_portion)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # map_location is required so checkpoints saved from GPU tensors
            # can be deserialized on a machine without CUDA.
            pretrain = torch.load(opt.pretrain_path,
                                  map_location=torch.device('cpu'))
            assert opt.arch == pretrain['arch']

            state_dict = pretrain['state_dict']
            if not hasattr(model, 'module'):
                # The model is not wrapped in DataParallel on this path, so
                # keys saved from a wrapped model ("module.<name>") can never
                # match; drop that prefix so such checkpoints still load.
                prefix = 'module.'
                state_dict = {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in state_dict.items()
                }
            model.load_state_dict(state_dict)

            # ``.module`` only exists on DataParallel wrappers; on this path
            # the model is normally the bare network, so fall back to it.
            net = getattr(model, 'module', model)

            if opt.model in [
                    'mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2'
            ]:
                net.classifier = nn.Sequential(
                    nn.Dropout(0.9),
                    nn.Linear(net.classifier[1].in_features,
                              opt.n_finetune_classes))
            elif opt.model == 'squeezenet':
                net.classifier = nn.Sequential(
                    nn.Dropout(p=0.5),
                    nn.Conv3d(net.classifier[1].in_channels,
                              opt.n_finetune_classes,
                              kernel_size=1), nn.ReLU(inplace=True),
                    nn.AvgPool3d((1, 4, 4), stride=1))
            else:
                net.fc = nn.Linear(net.fc.in_features,
                                   opt.n_finetune_classes)

            # NOTE(review): the CUDA path passes opt.ft_portion to this same
            # helper; confirm which option is actually intended here.
            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    # No pretrained checkpoint requested: train every parameter.
    return model, model.parameters()