コード例 #1
0
def get_model(p, pretrain_path=None):
    """Build the network described by the config ``p`` and optionally load weights.

    Args:
        p: Config mapping. Keys read here are 'backbone', 'train_db_name'
            and 'setup', plus 'model_kwargs' (contrastive setups) or
            'num_classes' and 'num_heads' (clustering setups).
        pretrain_path: Path to a checkpoint, or ``None`` to skip loading.

    Returns:
        The constructed model.

    Raises:
        ValueError: Unknown backbone or setup, or ``pretrain_path`` is given
            but does not exist on disk.
        NotImplementedError: No backbone implementation for the requested
            dataset, or a checkpoint was supplied for a setup other than
            'scan' / 'selflabel'.
        AssertionError: 'selflabel' with more than one head, or a 'scan'
            checkpoint whose extra keys are not a contrastive head.
    """
    # --- Backbone -----------------------------------------------------------
    backbone_name = p['backbone']
    if backbone_name == 'resnet18':
        dataset = p['train_db_name']
        if dataset in ('cifar-10', 'cifar-20'):
            from models.resnet_cifar import resnet18
        elif dataset == 'stl-10':
            from models.resnet_stl import resnet18
        else:
            raise NotImplementedError
        backbone = resnet18()

    elif backbone_name == 'resnet50':
        if 'imagenet' not in p['train_db_name']:
            raise NotImplementedError
        from models.resnet import resnet50
        backbone = resnet50()

    else:
        raise ValueError('Invalid backbone {}'.format(backbone_name))

    # --- Model wrapper for the chosen training setup ------------------------
    setup = p['setup']
    if setup in ('simclr', 'moco'):
        from models.models import ContrastiveModel
        model = ContrastiveModel(backbone, **p['model_kwargs'])

    elif setup in ('scan', 'selflabel'):
        from models.models import ClusteringModel
        if setup == 'selflabel':
            assert (p['num_heads'] == 1)
        model = ClusteringModel(backbone, p['num_classes'], p['num_heads'])

    else:
        raise ValueError('Invalid setup {}'.format(setup))

    # --- Pretrained weights -------------------------------------------------
    if pretrain_path is None:
        return model

    if not os.path.exists(pretrain_path):
        raise ValueError(
            'Path with pre-trained weights does not exist {}'.format(
                pretrain_path))

    checkpoint = torch.load(pretrain_path, map_location='cpu')

    if setup == 'scan':
        # Weights are expected to come from contrastive training, so the only
        # keys the clustering model may not recognise are those of the
        # contrastive head (either the two-layer or the single-layer variant).
        load_result = model.load_state_dict(checkpoint, strict=False)
        unexpected = set(load_result[1])  # second field: unexpected keys
        two_layer_head = {
            'contrastive_head.0.weight', 'contrastive_head.0.bias',
            'contrastive_head.2.weight', 'contrastive_head.2.bias'
        }
        single_layer_head = {'contrastive_head.weight', 'contrastive_head.bias'}
        assert unexpected == two_layer_head or unexpected == single_layer_head

    elif setup == 'selflabel':
        # Weights are expected to come from SCAN. Only the best head is kept:
        # remember its tensors, drop every head, then store it back as head 0.
        weights = checkpoint['model']
        head_keys = [key for key in weights.keys() if 'cluster_head' in key]
        best_weight = weights['cluster_head.%d.weight' % (checkpoint['head'])]
        best_bias = weights['cluster_head.%d.bias' % (checkpoint['head'])]
        for key in head_keys:
            weights.pop(key)

        weights['cluster_head.0.weight'] = best_weight
        weights['cluster_head.0.bias'] = best_bias
        model.load_state_dict(weights, strict=True)

    else:
        raise NotImplementedError

    return model
コード例 #2
0
ファイル: model.py プロジェクト: wrn7777/MscProject
def generate_model(opt):
    """Build the network selected by ``opt.model`` and choose the parameters to optimise.

    The architecture modules (``c3d``, ``squeezenet``, ``resnet``, ...) as
    well as ``nn``, ``torch`` and ``modify_kernels`` are resolved from module
    scope.

    Args:
        opt: Options object. Attributes read: ``model``, ``n_classes``, the
            per-architecture settings (``sample_size``, ``sample_duration``,
            ``version``, ``groups``, ``width_mult``, ``model_depth``,
            ``resnet_shortcut``, ``resnext_cardinality``), ``no_cuda``,
            ``pretrain_path``, ``pretrain_modality``, ``modality``,
            ``n_finetune_classes`` and ``ft_portion``.

    Returns:
        ``(model, parameters)``. When ``opt.pretrain_path`` is set,
        ``parameters`` is the result of
        ``get_fine_tuning_parameters(model, opt.ft_portion)`` on both the CUDA
        and the CPU path; otherwise it is ``model.parameters()``.

    Raises:
        AssertionError: ``opt.model`` or ``opt.model_depth`` is not supported.
    """
    assert opt.model in [
        'c3d', 'squeezenet', 'mobilenet', 'resnext', 'resnet', 'resnetl',
        'shufflenet', 'mobilenetv2', 'shufflenetv2'
    ]

    if opt.model == 'c3d':
        from models.c3d import get_fine_tuning_parameters
        model = c3d.get_model(num_classes=opt.n_classes,
                              sample_size=opt.sample_size,
                              sample_duration=opt.sample_duration)
    elif opt.model == 'squeezenet':
        from models.squeezenet import get_fine_tuning_parameters
        model = squeezenet.get_model(version=opt.version,
                                     num_classes=opt.n_classes,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'shufflenet':
        from models.shufflenet import get_fine_tuning_parameters
        model = shufflenet.get_model(groups=opt.groups,
                                     width_mult=opt.width_mult,
                                     num_classes=opt.n_classes)
    elif opt.model == 'shufflenetv2':
        from models.shufflenetv2 import get_fine_tuning_parameters
        model = shufflenetv2.get_model(num_classes=opt.n_classes,
                                       sample_size=opt.sample_size,
                                       width_mult=opt.width_mult)
    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        # The constructor name encodes the depth: resnext50, resnext101, ...
        build = getattr(resnext, 'resnext%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'resnetl':
        assert opt.model_depth in [10]
        from models.resnetl import get_fine_tuning_parameters
        build = getattr(resnetl, 'resnetl%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()
    # The network is wrapped on both devices, so it is reached through
    # ``model.module`` below.
    model = nn.DataParallel(model, device_ids=None)

    if not opt.pretrain_path:
        model = modify_kernels(opt, model, opt.modality)
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    pretrain = torch.load(opt.pretrain_path,
                          map_location=torch.device('cpu'))

    # Match the checkpoint's input modality first so its weights fit, then
    # swap the output layer and finally adapt to the target modality.
    model = modify_kernels(opt, model, opt.pretrain_modality)
    model.load_state_dict(pretrain['state_dict'])
    _replace_classifier_head(opt, model, use_cuda)
    model = modify_kernels(opt, model, opt.modality)

    parameters = get_fine_tuning_parameters(model, opt.ft_portion)
    return model, parameters


def _replace_classifier_head(opt, model, on_gpu):
    """Swap the wrapped network's output layer for one with ``opt.n_finetune_classes`` outputs."""
    net = model.module
    if opt.model in ['mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2']:
        net.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(net.classifier[1].in_features,
                      opt.n_finetune_classes))
        head_attr = 'classifier'
    elif opt.model == 'squeezenet':
        net.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv3d(net.classifier[1].in_channels,
                      opt.n_finetune_classes,
                      kernel_size=1), nn.ReLU(inplace=True),
            nn.AvgPool3d((1, 4, 4), stride=1))
        head_attr = 'classifier'
    else:
        net.fc = nn.Linear(net.fc.in_features, opt.n_finetune_classes)
        head_attr = 'fc'

    if on_gpu:
        # The freshly created layer lives on the CPU; move it next to the
        # rest of the network.
        setattr(net, head_attr, getattr(net, head_attr).cuda())
コード例 #3
0
def generate_model(opt, phase):
    """Build the segmentation or classification network for ``phase``.

    The architecture modules (``deeplab``, ``resnet``, ``wide_resnet``,
    ``resnext``, ``pre_act_resnet``, ``densenet``) and ``nn`` are resolved
    from module scope.

    Args:
        opt: Options object. Attributes read: ``seg_model``, ``in_channel``
            and ``n_classes`` for 'segment'; ``cla_model``,
            ``cla_model_depth``, ``n_classes`` and, depending on the
            architecture, ``cla_resnet_shortcut``, ``wide_resnet_k`` and
            ``resnext_cardinality`` for 'classify'; ``no_cuda`` always.
        phase: Either 'segment' or 'classify'.

    Returns:
        ``(model, model.parameters())``. Unless ``opt.no_cuda`` is set, the
        model is moved to the GPU and wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: ``phase`` is neither 'segment' nor 'classify'.
        AssertionError: The requested architecture or depth is not supported.
    """
    if phase == 'segment':
        assert opt.seg_model in ['deeplab']
        model = deeplab.Net(in_channel=opt.in_channel,
                            num_classes=opt.n_classes)

    elif phase == 'classify':
        assert opt.cla_model in [
            'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
        ]

        # In every family the constructor name encodes the depth
        # (resnet18, densenet121, ...), so it is looked up by name once the
        # depth has been validated.
        if opt.cla_model == 'resnet':
            assert opt.cla_model_depth in [10, 18, 34, 50, 101, 152, 200]
            build = getattr(resnet, 'resnet%d' % opt.cla_model_depth)
            model = build(num_classes=opt.n_classes,
                          shortcut_type=opt.cla_resnet_shortcut)
        elif opt.cla_model == 'wideresnet':
            assert opt.cla_model_depth in [50]
            build = getattr(wide_resnet, 'resnet%d' % opt.cla_model_depth)
            model = build(num_classes=opt.n_classes,
                          shortcut_type=opt.cla_resnet_shortcut,
                          k=opt.wide_resnet_k)
        elif opt.cla_model == 'resnext':
            assert opt.cla_model_depth in [50, 101, 152]
            build = getattr(resnext, 'resnet%d' % opt.cla_model_depth)
            model = build(num_classes=opt.n_classes,
                          shortcut_type=opt.cla_resnet_shortcut,
                          cardinality=opt.resnext_cardinality)
        elif opt.cla_model == 'preresnet':
            assert opt.cla_model_depth in [18, 34, 50, 101, 152, 200]
            build = getattr(pre_act_resnet, 'resnet%d' % opt.cla_model_depth)
            model = build(num_classes=opt.n_classes,
                          shortcut_type=opt.cla_resnet_shortcut)
        elif opt.cla_model == 'densenet':
            assert opt.cla_model_depth in [121, 169, 201, 264]
            build = getattr(densenet, 'densenet%d' % opt.cla_model_depth)
            model = build(num_classes=opt.n_classes)

    else:
        # Fail here with a clear message instead of hitting an unbound
        # ``model`` further down.
        raise ValueError(
            "Unknown phase {!r}; expected 'segment' or 'classify'".format(
                phase))

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)
    return model, model.parameters()
コード例 #4
0
def generate_model(opt):
    """Build the network selected by ``opt.model`` / ``opt.model_depth``.

    The architecture modules (``resnet``, ``wide_resnet``, ``resnext``,
    ``pre_act_resnet``, ``densenet``) as well as ``nn`` and ``torch`` are
    resolved from module scope.

    Args:
        opt: Options object. Attributes read: ``model``, ``model_depth``,
            ``n_classes``, ``sample_size``, ``sample_duration``, the
            per-architecture settings (``resnet_shortcut``, ``t_stride``,
            ``wide_resnet_k``, ``resnext_cardinality``), ``no_cuda``,
            ``pretrain_path``, ``arch``, ``n_finetune_classes`` and
            ``ft_begin_index``.

    Returns:
        ``(model, parameters)``. With ``opt.pretrain_path`` set,
        ``parameters`` comes from
        ``get_fine_tuning_parameters(model, opt.ft_begin_index)``; otherwise
        it is ``model.parameters()``.

    Raises:
        AssertionError: Unsupported architecture or depth, or the
            checkpoint's 'arch' entry differs from ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    # In every family the constructor name encodes the depth, so it is looked
    # up by name once the depth has been validated.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]
        from models.resnet import get_fine_tuning_parameters
        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration,
                      t_stride=opt.t_stride)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]
        from models.wide_resnet import get_fine_tuning_parameters
        build = getattr(wide_resnet, 'resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      k=opt.wide_resnet_k,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]
        from models.resnext import get_fine_tuning_parameters
        build = getattr(resnext, 'resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]
        from models.pre_act_resnet import get_fine_tuning_parameters
        build = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]
        from models.densenet import get_fine_tuning_parameters
        build = getattr(densenet, 'densenet%d' % opt.model_depth)
        model = build(num_classes=opt.n_classes,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    # Without CUDA, remap the stored tensors to the CPU so a checkpoint saved
    # from a GPU run can still be deserialised. The CUDA path keeps
    # torch.load's default placement.
    pretrain = torch.load(opt.pretrain_path,
                          map_location=None if use_cuda else 'cpu')
    assert opt.arch == pretrain['arch']

    model.load_state_dict(pretrain['state_dict'])

    # Only the CPU path swaps the output layer for one sized for
    # ``opt.n_finetune_classes``; on the CUDA path the layer stays as loaded.
    if not use_cuda:
        if opt.model == 'densenet':
            model.classifier = nn.Linear(model.classifier.in_features,
                                         opt.n_finetune_classes)
        else:
            model.fc = nn.Linear(model.fc.in_features,
                                 opt.n_finetune_classes)

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
コード例 #5
0
ファイル: model.py プロジェクト: zymale/MedicalNet
def generate_model(opt):
    """Build the ResNet selected by ``opt.model_depth`` and optionally load pretrained weights.

    ``resnet``, ``nn`` and ``torch`` are resolved from module scope.

    Args:
        opt: Options object. Attributes read: ``model``, ``model_depth``,
            ``input_W``, ``input_H``, ``input_D``, ``resnet_shortcut``,
            ``no_cuda``, ``n_seg_classes``, ``gpu_id``, ``pretrain_path``
            and ``new_layer_names``.

    Returns:
        ``(model, parameters)``. With ``opt.pretrain_path`` set,
        ``parameters`` is a dict with the keys 'base_parameters' and
        'new_parameters'; the latter holds every parameter whose name
        contains one of ``opt.new_layer_names``. Otherwise it is
        ``model.parameters()``.

    Raises:
        AssertionError: Unsupported ``opt.model`` or ``opt.model_depth``.
    """
    assert opt.model in ['resnet']
    assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

    # The constructor name encodes the depth: resnet10, resnet18, ...
    build = getattr(resnet, 'resnet%d' % opt.model_depth)
    model = build(sample_input_W=opt.input_W,
                  sample_input_H=opt.input_H,
                  sample_input_D=opt.input_D,
                  shortcut_type=opt.resnet_shortcut,
                  no_cuda=opt.no_cuda,
                  num_seg_classes=opt.n_seg_classes)

    if not opt.no_cuda:
        if len(opt.gpu_id) > 1:
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=opt.gpu_id)
        else:
            import os
            # Single GPU: expose only that device to CUDA before moving the
            # model onto it.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.gpu_id[0])
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=None)
    net_dict = model.state_dict()

    if not opt.pretrain_path:
        return model, model.parameters()

    # load pretrain
    print('loading pretrained model {}'.format(opt.pretrain_path))
    # Without CUDA, remap the stored tensors to the CPU so a checkpoint saved
    # from a GPU run can still be deserialised.
    pretrain = torch.load(opt.pretrain_path,
                          map_location='cpu' if opt.no_cuda else None)
    # Keep only the checkpoint entries this network actually has; all other
    # entries of net_dict retain their freshly initialised values.
    pretrain_dict = {
        k: v
        for k, v in pretrain['state_dict'].items() if k in net_dict
    }
    net_dict.update(pretrain_dict)
    model.load_state_dict(net_dict)

    new_parameters = [
        p for pname, p in model.named_parameters()
        if any(layer_name in pname for layer_name in opt.new_layer_names)
    ]
    # Compare by identity via a set, so the split is linear in the number of
    # parameters.
    new_parameter_ids = {id(p) for p in new_parameters}
    base_parameters = [
        p for p in model.parameters() if id(p) not in new_parameter_ids
    ]
    parameters = {
        'base_parameters': base_parameters,
        'new_parameters': new_parameters
    }

    return model, parameters
コード例 #6
0
 },
 'dpn107': {
     'filters': [128, 376, 1152, 2432, 2688],
     'init_op':
     dpn107,
     'last_upsample':
     64,
     'decoder_filters': [64, 128, 256, 256],
     'url':
     'http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-1ac7121e2.pth'
 },
 'resnet50': {
     'filters': [64, 256, 512, 1024, 2048],
     'decoder_filters': [64, 128, 256, 256],
     'last_upsample': 64,
     'init_op': resnet.resnet50(),
     'url': resnet.model_urls['resnet50']
 },
 'densenet121': {
     'filters': [64, 256, 512, 1024, 1024],
     'decoder_filters': [64, 128, 256, 256],
     'last_upsample': 64,
     'url': None,
     'init_op': densenet121
 },
 'densenet169': {
     'filters': [64, 256, 512, 1280, 1664],
     'decoder_filters': [64, 128, 256, 256],
     'last_upsample': 64,
     'url': None,
     'init_op': densenet169
コード例 #7
0
ファイル: main_mgd.py プロジェクト: unsky/mgd
def main_worker(gpu, ngpus_per_node, args):
    """Train (or only evaluate) a student network with a distiller on one worker.

    The function builds a teacher/student pair for ``args.arch``, wraps them
    in a distiller, moves everything to the requested device(s), builds the
    optimizer, LR scheduler and data loaders, evaluates both networks once,
    and then runs the epoch loop with periodic distiller updates and
    checkpointing.

    Args:
        gpu: GPU index for this worker, or ``None``; stored into ``args.gpu``.
        ngpus_per_node: number of GPUs per node; used to derive the global
            rank and to split batch size / workers across processes.
        args: option namespace. It is mutated in place: ``gpu``, ``rank``,
            ``batch_size``, ``workers``, ``epochs`` and ``lr_drop_epochs``
            may all be rewritten here.

    Raises:
        ValueError: if ``args.arch`` is not a supported student architecture.
        NotImplementedError: if ``args.distiller`` is not ``'mgd'``.
    """
    global best_acc1, best_acc5
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # create model: each student arch is paired with a fixed pretrained teacher
    if args.arch == 'mobilenet_v1':
        from models import resnet, mobilenetv1
        t_net = resnet.resnet50(pretrained=True)
        s_net = mobilenetv1.mobilenet_v1()
        ignore_inds = []
    elif args.arch == 'mobilenet_v2':
        from models import resnet, mobilenetv2
        t_net = resnet.resnet50(pretrained=True)
        s_net = mobilenetv2.mobilenet_v2(pretrained=bool(args.use_pretrained))
        ignore_inds = [0]
    elif args.arch == 'resnet50':
        from models import resnet
        t_net = resnet.resnet152(pretrained=True)
        s_net = resnet.resnet50(pretrained=bool(args.use_pretrained))
        ignore_inds = []
    elif args.arch == 'shufflenet_v2':
        from models import resnet, shufflenetv2
        t_net = resnet.resnet50(pretrained=True)
        s_net = shufflenetv2.shufflenet_v2_x1_0(
            pretrained=bool(args.use_pretrained))
        ignore_inds = [0]
    else:
        raise ValueError

    if args.distiller == 'mgd':
        # normal and special reducers
        norm_reducers = ['amp', 'rd', 'sp']
        spec_reducers = ['sm']
        assert args.mgd_reducer in norm_reducers + spec_reducers

        # create distiller
        distiller = mgd.builder.MGDistiller if args.mgd_reducer in norm_reducers \
               else mgd.builder.SMDistiller

        d_net = distiller(t_net,
                          s_net,
                          ignore_inds=ignore_inds,
                          reducer=args.mgd_reducer,
                          sync_bn=args.sync_bn,
                          with_kd=args.mgd_with_kd,
                          preReLU=True,
                          distributed=args.distributed)
    else:
        raise NotImplementedError

    # model size
    print('the number of teacher model parameters: {}'.format(
        sum([p.data.nelement() for p in t_net.parameters()])))
    print('the number of student model parameters: {}'.format(
        sum([p.data.nelement() for p in s_net.parameters()])))
    print('the total number of model parameters: {}'.format(
        sum([p.data.nelement() for p in d_net.parameters()])))

    # dp convert
    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            t_net.cuda(args.gpu)
            s_net.cuda(args.gpu)
            d_net.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            if args.sync_bn:
                s_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(s_net)
            t_net = torch.nn.parallel.DistributedDataParallel(
                t_net, find_unused_parameters=True, device_ids=[args.gpu])
            s_net = torch.nn.parallel.DistributedDataParallel(
                s_net, find_unused_parameters=True, device_ids=[args.gpu])
            d_net = torch.nn.parallel.DistributedDataParallel(
                d_net, find_unused_parameters=True, device_ids=[args.gpu])
        else:
            t_net.cuda()
            s_net.cuda()
            d_net.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            if args.sync_bn:
                s_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(s_net)
            t_net = torch.nn.parallel.DistributedDataParallel(
                t_net, find_unused_parameters=True)
            s_net = torch.nn.parallel.DistributedDataParallel(
                s_net, find_unused_parameters=True)
            d_net = torch.nn.parallel.DistributedDataParallel(
                d_net, find_unused_parameters=True)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        t_net = t_net.cuda(args.gpu)
        s_net = s_net.cuda(args.gpu)
        d_net = d_net.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        t_net = torch.nn.DataParallel(t_net).cuda()
        s_net = torch.nn.DataParallel(s_net).cuda()
        d_net = torch.nn.DataParallel(d_net).cuda()

    # define loss function (criterion), optimizer and lr_scheduler
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # d_net is only wrapped (and therefore only exposes `.module`) on the
    # DataParallel / DistributedDataParallel paths above; on the single-GPU
    # and CPU paths it is still the bare distiller.
    wrappers = (torch.nn.DataParallel,
                torch.nn.parallel.DistributedDataParallel)
    raw_d_net = d_net.module if isinstance(d_net, wrappers) else d_net
    model_params = list(s_net.parameters()) + list(
        raw_d_net.BNs.parameters())
    optimizer = torch.optim.SGD(model_params,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    # warmup setting: warmup epochs are added on top of the schedule, so the
    # total epoch count and every LR drop milestone shift by the same amount
    if args.warmup:
        args.epochs += args.warmup_epochs
        args.lr_drop_epochs = list(
            np.array(args.lr_drop_epochs) + args.warmup_epochs)
    lr_scheduler = build_lr_scheduler(optimizer, args)

    # optionally resume from a checkpoint
    load_checkpoint(t_net, args.teacher_resume, args)
    load_checkpoint(s_net, args.student_resume, args)

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    validdir = os.path.join(args.data, 'valid')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    valid_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        validdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True)

    if args.distributed:
        extra_sampler = mgd.sampler.ExtraDistributedSampler(train_dataset)
    else:
        extra_sampler = None

    # The extra loader reads the training images with the deterministic
    # validation transform; it is the input of mgd_update below.
    extra_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=(extra_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=extra_sampler)

    print('=> evaluate teacher model')
    validate(valid_loader, t_net, criterion, args)
    print('=> evaluate student model')
    validate(valid_loader, s_net, criterion, args)
    if args.evaluate:
        return

    if args.distiller == 'mgd':
        mgd_update(extra_loader, d_net, args)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, d_net, criterion, optimizer, lr_scheduler, epoch,
              args)

        # evaluate on validation set
        acc1, acc5 = validate(valid_loader, s_net, criterion, args)

        # update flow matrix for the next round
        if args.distiller == 'mgd' and (epoch + 1) % args.mgd_update_freq == 0:
            mgd_update(extra_loader, d_net, args)

        # remember best acc@1 and save checkpoint; best_acc5 is the acc@5 of
        # the epoch that achieved the best acc@1
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = acc5 if is_best else best_acc5

        print(' * - Best - Err@1 {acc1:.3f} Err@5 {acc5:.3f}'.format(
            acc1=(100 - best_acc1), acc5=(100 - best_acc5)))

        # only one process per node writes the checkpoint
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            filename = '{}.pth'.format(args.arch)
            save_checkpoint(
                args, {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': s_net.state_dict(),
                    'best_acc1': best_acc1,
                    # store the tracked best, not this epoch's acc5
                    'best_acc5': best_acc5,
                    'optimizer': optimizer.state_dict(),
                }, is_best, filename)
        lr_scheduler.step()
        gc.collect()
コード例 #8
0
    raise "only support dataset CIFAR10 or CIFAR100"

# Build the network selected by args.model. Every constructor receives the
# dataset's class count; all but LeNet also receive the pretrain flag.
if args.model == "lenet":
    net = lenet.LeNet(num_classes=num_classes)

elif args.model == "vgg16":
    net = vgg.vgg16(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "vgg16_bn":
    net = vgg.vgg16_bn(num_classes=num_classes, pretrained=args.pretrain)

elif args.model == "resnet18":
    net = resnet.resnet18(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "resnet34":
    # previously built resnet18 here (copy-paste slip)
    net = resnet.resnet34(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "resnet50":
    net = resnet.resnet50(num_classes=num_classes, pretrained=args.pretrain)

elif args.model == "resnetv2_18":
    net = resnet_v2.resnet18(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "resnetv2_34":
    # previously built resnet18 here (copy-paste slip)
    net = resnet_v2.resnet34(num_classes=num_classes, pretrained=args.pretrain)
elif args.model == "resnetv2_50":
    net = resnet_v2.resnet50(num_classes=num_classes, pretrained=args.pretrain)

elif args.model == "resnext50_32x4d":
    net = resnextto.resnext50_32x4d(num_classes=num_classes,
                                    pretrained=args.pretrain)
elif args.model == "resnext101_32x8d":
    net = resnextto.resnext101_32x8d(num_classes=num_classes,
                                     pretrained=args.pretrain)
コード例 #9
0
def define_model(is_resnet,
                 is_densenet,
                 is_senet,
                 model='tbdp',
                 parallel=False,
                 semff=False,
                 pcamff=False):
    """Build a TBDPNet or Hu network on top of one pretrained encoder.

    Args:
        is_resnet: select the ResNet-50 encoder.
        is_densenet: select the DenseNet-161 encoder.
        is_senet: select the SENet-154 encoder.
        model: head type, ``'tbdp'`` or ``'hu'``.
        parallel: forwarded to ``net.TBDPNet`` only.
        semff: forwarded to ``net.Hu`` only.
        pcamff: forwarded to both head types.

    Returns:
        The constructed network.

    Raises:
        NotImplementedError: if ``model`` is not ``'tbdp'`` or ``'hu'``.
        ValueError: if not exactly one of the three encoder flags is set.
    """
    # Validate up front, before any pretrained weights are loaded.
    if model not in ('tbdp', 'hu'):
        raise NotImplementedError("Select model type in ['tbdp', 'hu']")

    n_selected = sum(
        bool(flag) for flag in (is_resnet, is_densenet, is_senet))
    if n_selected != 1:
        # Previously: no flag returned the `model` string itself, and several
        # flags failed with a misleading "Select model type" error.
        raise ValueError(
            'Exactly one of is_resnet, is_densenet, is_senet must be set')

    # Pick the encoder together with the channel layout the heads expect.
    if is_resnet:
        original_model = resnet.resnet50(pretrained=True)
        encoder = modules.E_resnet(original_model)
        num_features = 2048
        block_channel = [256, 512, 1024, 2048]
    elif is_densenet:
        original_model = densenet.densenet161(pretrained=True)
        encoder = modules.E_densenet(original_model)
        num_features = 2208
        block_channel = [192, 384, 1056, 2208]
    else:
        original_model = senet.senet154(pretrained='imagenet')
        encoder = modules.E_senet(original_model)
        num_features = 2048
        block_channel = [256, 512, 1024, 2048]

    if model == 'tbdp':
        return net.TBDPNet(encoder,
                           num_features=num_features,
                           block_channel=block_channel,
                           parallel=parallel,
                           pcamff=pcamff)
    return net.Hu(encoder,
                  num_features=num_features,
                  block_channel=block_channel,
                  semff=semff,
                  pcamff=pcamff)
コード例 #10
0
def build_model(args):
    """Instantiate the network named by ``args.arch`` and ``args.model_depth``.

    The factory is looked up by name as ``<arch><depth>`` inside the module
    called ``<arch>`` (e.g. ``resnet.resnet50``, ``iresgroup.iresgroup101``).

    Args:
        args: option namespace providing ``arch``, ``model_depth``,
            ``pretrained``, ``n_classes``, ``zero_init_residual`` and, for
            the grouped architectures only, ``groups``.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``args.arch`` is not a supported architecture
            (previously this surfaced as ``UnboundLocalError``).
        AssertionError: if ``args.model_depth`` is not available for the arch.
    """
    arch = args.arch
    group_depths = (50, 101, 152)
    plain_depths = (18, 34, 50, 101, 152, 200)

    # The module is resolved branch by branch so that only the module of the
    # requested architecture has to be present in the file's namespace.
    if arch == 'iresgroup':
        module, depths, grouped = iresgroup, group_depths, True
    elif arch == 'iresgroupfix':
        module, depths, grouped = iresgroupfix, group_depths, True
    elif arch == 'resgroupfix':
        module, depths, grouped = resgroupfix, group_depths, True
    elif arch == 'resgroup':
        module, depths, grouped = resgroup, group_depths, True
    elif arch == 'iresnet':
        module, depths, grouped = (iresnet,
                                   plain_depths + (302, 404, 1001), False)
    elif arch == 'resstage':
        module, depths, grouped = resstage, plain_depths, False
    elif arch == 'resnet':
        module, depths, grouped = resnet, plain_depths, False
    else:
        raise ValueError('Unsupported arch: {!r}'.format(arch))

    assert args.model_depth in depths

    # int() keeps the name lookup working for depths given as e.g. 50.0,
    # which the membership test above accepts.
    factory = getattr(module, '{}{}'.format(arch, int(args.model_depth)))

    kwargs = {
        'pretrained': args.pretrained,
        'num_classes': args.n_classes,
        'zero_init_residual': args.zero_init_residual,
    }
    if grouped:
        # only the grouped architectures take (and read) args.groups
        kwargs['groups'] = args.groups
    return factory(**kwargs)
コード例 #11
0
ファイル: train.py プロジェクト: iceicei/MSSC
def main(args):
    """Train an MSSC attribute model and report the best validation metric.

    Sets up experiment paths under ``exp_result/<dataset>``, builds the
    train/valid loaders, the model, the loss, a two-group SGD optimizer and
    a plateau LR scheduler, then delegates the loop to ``trainer``.

    Args:
        args: option namespace. Read here: ``dataset``, ``redirector``,
            ``device``, ``train_split``, ``valid_split``, ``batchsize``,
            ``lr_ft``, ``lr_new``, ``momentum``, ``weight_decay`` and
            ``train_epoch``.
    """
    visenv_name = args.dataset
    exp_dir = os.path.join('exp_result', args.dataset)
    model_dir, log_dir = get_model_log_path(exp_dir, visenv_name)
    stdout_file = os.path.join(log_dir, f'stdout_{time_str()}.txt')
    save_model_path = os.path.join(model_dir, 'ckpt_max.pth')

    if args.redirector:
        print('redirector stdout')
        ReDirectSTD(stdout_file, 'stdout', False)

    pprint.pprint(OrderedDict(args.__dict__))

    print('-' * 60)
    print(f'use GPU{args.device} for training')
    print(
        f'train set: {args.dataset} {args.train_split}, test set: {args.valid_split}'
    )

    train_tsfm, valid_tsfm = get_transform(args)
    print(train_tsfm)

    train_set = AttrDataset(args=args,
                            split=args.train_split,
                            transform=train_tsfm)

    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.batchsize,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=True)
    valid_set = AttrDataset(args=args,
                            split=args.valid_split,
                            transform=valid_tsfm)

    # NOTE(review): drop_last=True also discards the final partial batch of
    # the validation split - confirm this is intended.
    valid_loader = DataLoader(dataset=valid_set,
                              batch_size=args.batchsize,
                              shuffle=False,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=True)

    print(f'{args.train_split} set: {len(train_loader.dataset)}, '
          f'{args.valid_split} set: {len(valid_loader.dataset)}, '
          f'attr_num : {train_set.attr_num}')

    # Mean of the training labels along axis 0, passed to the loss below.
    labels = train_set.label
    sample_weight = labels.mean(0)

    backbone = resnet50()
    classifier = BaseClassifier(nattr=train_set.attr_num)

    model = MSSC(backbone, classifier)

    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()

    criterion = CEL_Sigmoid(sample_weight)

    # The model is only wrapped in DataParallel when CUDA is available, so
    # unwrap conditionally instead of assuming `.module` exists.
    core = model.module if isinstance(model, torch.nn.DataParallel) else model
    param_groups = [{
        'params': core.finetune_params(),
        'lr': args.lr_ft
    }, {
        'params': core.fresh_params(),
        'lr': args.lr_new
    }]

    optimizer = torch.optim.SGD(param_groups,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=False)
    lr_scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=4)

    best_metric, epoch = trainer(epoch=args.train_epoch,
                                 model=model,
                                 train_loader=train_loader,
                                 valid_loader=valid_loader,
                                 criterion=criterion,
                                 optimizer=optimizer,
                                 lr_scheduler=lr_scheduler,
                                 path=save_model_path,
                                 dataset=args.dataset)

    print(f'{visenv_name},  best_metrc : {best_metric} in epoch{epoch}')
コード例 #12
0
ファイル: resnet.py プロジェクト: theodorblackbird/smia
def run():
    """Attack correctly classified CIFAR-100 test images and show one example.

    Steps: load a ResNet-50 checkpoint, collect the test images the network
    classifies correctly, run ``attack.smia_attack`` on batches of them,
    gather the images whose prediction changed, print a ratio and display
    one clean/perturbed pair with matplotlib.
    """
    torch.multiprocessing.freeze_support()
    import torchgeometry

    # per-channel statistics handed to transforms.Normalize
    mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

    # load the resnet50 model
    device = torch.device("cuda")
    net = resnet50().to(device)
    net.load_state_dict(torch.load("resnet50-90-regular.pth"))

    preprocess = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    test_set = torchvision.datasets.CIFAR100(root='./data',
                                             train=False,
                                             download=True,
                                             transform=preprocess)
    test_loader = DataLoader(test_set,
                             shuffle=True,
                             num_workers=4,
                             batch_size=BATCH_SIZE)

    n_samples = 0
    n_well_classified = 0
    n_well_classified_top5 = 0
    net.eval()

    kept_images = torch.tensor([]).to(device)
    kept_targets = torch.tensor([]).to(device)

    # Pass 1: keep the samples of the first 101 batches that are classified
    # correctly.
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(test_loader):
            if batch_idx > 100:
                break
            images = images.to(device)
            targets = targets.to(device)
            n_samples += images.shape[0]

            predicted = net(images).argmax(1)
            hit = (predicted == targets)
            n_well_classified += hit.sum()

            kept_images = torch.cat((kept_images, images[hit]), 0)
            kept_targets = torch.cat((kept_targets, targets[hit]), 0)

    well_DS = wellImageDataset(kept_images.cpu(), kept_targets.cpu())
    well_DL = DataLoader(well_DS, shuffle=False, num_workers=1, batch_size=10)

    fooled_clean = torch.tensor([]).to(device)
    fooled_perturbed = torch.tensor([]).to(device)
    fooled_targets = torch.tensor([]).to(device)
    n_misc = 0

    # Pass 2: attack each batch and collect the samples whose prediction no
    # longer matches the target.
    for step, (clean, labels) in enumerate(well_DL):
        clean = clean.cuda()
        labels = labels.cuda()
        perturbed = attack.smia_attack(clean, 0.005, 0, net, labels.long(),
                                       10)
        still_correct = net(perturbed).argmax(1) == labels
        n_still_correct = still_correct.sum()
        n_misc += n_still_correct

        fooled = ~still_correct
        fooled_clean = torch.cat((fooled_clean, clean[fooled]), 0)
        fooled_perturbed = torch.cat((fooled_perturbed, perturbed[fooled]), 0)
        fooled_targets = torch.cat((fooled_targets, labels[fooled]), 0)
        if step > 10:
            break
        print(step)
        print(n_still_correct)

    # NOTE(review): n_misc accumulates the samples that are STILL classified
    # correctly over the attacked batches only, while the denominator is the
    # size of the whole well-classified set - confirm the intended metric.
    print(n_misc / len(kept_images))

    # Show the second collected example, clean and perturbed.
    example = denormalize_image(fooled_clean[1].unsqueeze(0))
    example_perb = denormalize_image(fooled_perturbed[1].unsqueeze(0))
    example = example.detach().cpu().squeeze()
    example_perb = example_perb.detach().cpu().squeeze()
    for picture in (example, example_perb):
        plt.imshow(picture)
        plt.show()
コード例 #13
0
ファイル: resnet2d_int.py プロジェクト: asw91666/gcnet_test
    def __init__(self, num_landmarks):
        """Set up the encoder, the deconvolution head and integration buffers.

        Args:
            num_landmarks: number of output channels of the final 1x1
                regression convolution.
        """
        super(ResNet50Int_global, self).__init__()

        self.num_landmarks = num_landmarks

        # Encoder: the project's resnet50 (the original comment describes it
        # as "resnet with gcnet"), created with pretrained=True.
        self.encoder = resnet.resnet50(pretrained=True)

        def upsample_block(in_channels):
            # ConvTranspose2d(kernel 4, stride 2, padding 1) doubles the
            # spatial size; followed by batch norm and ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels,
                                   256,
                                   kernel_size=4,
                                   stride=2,
                                   padding=1,
                                   bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True))

        # Head: three deconv blocks (2048 -> 256 -> 256 -> 256 channels)
        # followed by the 1x1 regression layer.
        self.module = nn.ModuleList([
            upsample_block(2048),
            upsample_block(256),
            upsample_block(256),
            nn.Conv2d(256, num_landmarks, 1),
        ])

        # For integration: coordinate weights 2, 6, ..., 254 laid out over a
        # flattened 64x64 grid; wx varies along columns, wy along rows.
        self.relu = nn.ReLU(inplace=True)
        x_coords = torch.arange(64.0) * 4.0 + 2.0
        self.register_buffer(
            'wx',
            x_coords.reshape(1, 64).repeat(64, 1).reshape(64 * 64, 1))
        y_coords = torch.arange(64.0) * 4.0 + 2.0
        self.register_buffer(
            'wy',
            y_coords.reshape(64, 1).repeat(1, 64).reshape(64 * 64, 1))

        self.fliptest = False
コード例 #14
0
def generate_model(opt):
    """Build the network selected by ``opt`` and return it with its parameters.

    The architecture family is chosen by ``opt.model`` and, where relevant,
    ``opt.model_depth``. Unless ``opt.no_cuda`` is set the model is wrapped in
    ``nn.DataParallel`` (and moved to the GPU unless ``opt.no_cuda_predict`` is
    set). When ``opt.pretrain_path`` is truthy, the checkpoint is loaded, its
    ``'arch'`` entry is checked against ``opt.arch``, and the final layer is
    replaced by a fresh one with ``opt.n_finetune_classes`` outputs.

    Args:
        opt: options object; the attributes read are ``model``,
            ``model_depth``, ``n_classes``, ``sample_size``,
            ``sample_duration``, ``resnet_shortcut``, ``wide_resnet_k``,
            ``resnext_cardinality``, ``width_mult``, ``no_cuda``,
            ``no_cuda_predict``, ``pretrain_path``, ``arch``,
            ``ft_begin_index`` and ``n_finetune_classes`` (only those needed
            by the selected configuration).

    Returns:
        A ``(model, parameters)`` tuple. Without a pretrained checkpoint
        ``parameters`` is ``model.parameters()``; otherwise it is whatever the
        family's ``get_fine_tuning_parameters`` returns.

    Raises:
        AssertionError: unsupported ``opt.model`` / ``opt.model_depth``, or
            the checkpoint architecture differs from ``opt.arch``.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet',
        'mobilenet', 'mobilenetv2'
    ]

    # Each family exposes its own get_fine_tuning_parameters; the matching one
    # is imported lazily so only the selected family's module is required.
    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        # The depth has been validated above, so the constructor can be
        # looked up by name instead of spelling out one branch per depth.
        build = getattr(resnet, 'resnet{}'.format(opt.model_depth))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = wide_resnet.resnet50(num_classes=opt.n_classes,
                                     shortcut_type=opt.resnet_shortcut,
                                     k=opt.wide_resnet_k,
                                     sample_size=opt.sample_size,
                                     sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnet{}'.format(opt.model_depth))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      cardinality=opt.resnext_cardinality,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        build = getattr(pre_act_resnet, 'resnet{}'.format(opt.model_depth))
        model = build(num_classes=opt.n_classes,
                      shortcut_type=opt.resnet_shortcut,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        build = getattr(densenet, 'densenet{}'.format(opt.model_depth))
        model = build(num_classes=opt.n_classes,
                      sample_size=opt.sample_size,
                      sample_duration=opt.sample_duration)
    elif opt.model == 'mobilenet':
        from models.mobilenet import get_fine_tuning_parameters
        model = mobilenet.get_model(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)
    elif opt.model == 'mobilenetv2':
        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.get_model(num_classes=opt.n_classes,
                                      sample_size=opt.sample_size,
                                      width_mult=opt.width_mult)

    use_cuda = not opt.no_cuda
    if use_cuda:
        if not opt.no_cuda_predict:
            model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

    if not opt.pretrain_path:
        return model, model.parameters()

    print('loading pretrained model {}'.format(opt.pretrain_path))
    if use_cuda:
        pretrain = torch.load(opt.pretrain_path)
        print("Pretrain arch", pretrain['arch'])
    else:
        # Remap storages so a checkpoint saved from a GPU run can still be
        # read when CUDA use has been disabled.
        pretrain = torch.load(opt.pretrain_path, map_location='cpu')
    assert opt.arch == pretrain['arch']

    model.load_state_dict(pretrain['state_dict'])

    # Only the CUDA path wraps the network in DataParallel, so only there do
    # the real layers live under ``.module``. The CPU path previously reached
    # through ``model.module`` for the mobilenet family, which does not exist
    # on an unwrapped model.
    net = model.module if use_cuda else model

    def _place(layer):
        # New heads are moved to the GPU only on the CUDA path; the CPU path
        # must not call .cuda() at all.
        # NOTE(review): on the CUDA path the head is moved to the GPU even
        # when opt.no_cuda_predict kept the backbone on the CPU; this mirrors
        # the previous behaviour -- confirm it is intended.
        return layer.cuda() if use_cuda else layer

    ft_begin_index = opt.ft_begin_index
    if opt.model in [
            'mobilenet', 'mobilenetv2', 'shufflenet', 'shufflenetv2'
    ]:
        net.classifier = _place(
            nn.Sequential(
                nn.Dropout(0.9),
                nn.Linear(net.classifier[1].in_features,
                          opt.n_finetune_classes)))
        # For this family the numeric index is collapsed into one of two
        # string modes before being handed to get_fine_tuning_parameters.
        ft_begin_index = 'complete' if ft_begin_index == 0 else 'last_layer'
    elif opt.model == 'densenet':
        net.classifier = _place(
            nn.Linear(net.classifier.in_features, opt.n_finetune_classes))
    else:
        net.fc = _place(
            nn.Linear(net.fc.in_features, opt.n_finetune_classes))

    print("Finetuning at:", ft_begin_index)
    parameters = get_fine_tuning_parameters(model, ft_begin_index)
    return model, parameters
コード例 #15
0
ファイル: model.py プロジェクト: y6216886/darklight_concat
def generate_model(opt):
    """Build the 3D network selected by ``opt`` and return it with its parameters.

    The architecture family is chosen by ``opt.model`` and ``opt.model_depth``.
    Unless ``opt.no_cuda`` is set the model is moved to the GPU (it is *not*
    wrapped in ``nn.DataParallel``). When ``opt.pretrain`` is truthy, that
    checkpoint is loaded, its ``'arch'`` entry is checked against ``opt.arch``
    and the final layer is replaced by a fresh one with
    ``opt.n_finetune_classes`` outputs.

    Args:
        opt: options object; the attributes read are ``model``,
            ``model_depth``, ``n_classes``, ``sample_size``,
            ``sample_duration``, ``resnet_shortcut``, ``wide_resnet_k``,
            ``resnext_cardinality``, ``no_cuda``, ``pretrain``, ``arch``,
            ``ft_begin_index`` and ``n_finetune_classes`` (only those needed
            by the selected configuration).

    Returns:
        A ``(model, parameters)`` tuple. Without a pretrained checkpoint
        ``parameters`` is ``model.parameters()``; otherwise it is whatever the
        family's ``get_fine_tuning_parameters`` returns.

    Raises:
        AssertionError: unsupported ``opt.model`` / ``opt.model_depth``, or
            the checkpoint architecture differs from ``opt.arch``.
    """
    assert opt.model in [
        'resnet3D', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    # Each family exposes its own get_fine_tuning_parameters; the matching one
    # is imported lazily so only the selected family's module is required.
    if opt.model == 'resnet3D':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        # The depth has been validated above, so the constructor can be
        # looked up by name instead of spelling out one branch per depth.
        build = getattr(resnet, 'resnet{}'.format(opt.model_depth))
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        model = wide_resnet.resnet50(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            k=opt.wide_resnet_k,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnet{}'.format(opt.model_depth))
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        build = getattr(pre_act_resnet, 'resnet{}'.format(opt.model_depth))
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        build = getattr(densenet, 'densenet{}'.format(opt.model_depth))
        model = build(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)

    use_cuda = not opt.no_cuda
    if use_cuda:
        model = model.cuda()
        # model = nn.DataParallel(model, device_ids=None)

    if not opt.pretrain:
        return model, model.parameters()

    # The checkpoint path lives in opt.pretrain; the CUDA branch used to log
    # opt.pretrain_path here, which is not the value that gets loaded.
    print('loading pretrained model {}'.format(opt.pretrain))
    if use_cuda:
        pretrain = torch.load(opt.pretrain)
    else:
        # Remap storages so a checkpoint saved from a GPU run can still be
        # read when CUDA use has been disabled.
        pretrain = torch.load(opt.pretrain, map_location='cpu')
    assert opt.arch == pretrain['arch']

    model.load_state_dict(pretrain['state_dict'])

    # The model is not wrapped in DataParallel (see the commented-out line
    # above), so its layers are attributes of ``model`` itself; going through
    # ``model.module`` would raise AttributeError.
    if opt.model == 'densenet':
        head = nn.Linear(model.classifier.in_features,
                         opt.n_finetune_classes)
        model.classifier = head.cuda() if use_cuda else head
    else:
        head = nn.Linear(model.fc.in_features, opt.n_finetune_classes)
        model.fc = head.cuda() if use_cuda else head

    parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
    return model, parameters
















# import torch.nn as nn
# import math
# import numpy as np
# import torch
# from torch.autograd import Variable
# from torch.nn.parameter import Parameter
# import random
# import torch.utils.model_zoo as model_zoo
#
# def initweights(m):
#     orthogonal_flag = False
#     for layer in m.modules():
#         if isinstance(layer, nn.Conv2d):
#             n = layer.kernel_size[0] * layer.kernel_size[1] * layer.out_channels
#             layer.weight.data.normal_(0, math.sqrt(2. / n))
#
#
#
#             # orthogonal initialize
#             """Reference:
#             [1] Saxe, Andrew M., James L. McClelland, and Surya Ganguli.
#                "Exact solutions to the nonlinear dynamics of learning in deep
#                linear neural networks." arXiv preprint arXiv:1312.6120 (2013)."""
#             if orthogonal_flag:
#                 weight_shape = layer.weight.data.cpu().numpy().shape
#                 u, _, v = np.linalg.svd(layer.weight.data.cpu().numpy(), full_matrices=False)
#                 flat_shape = (weight_shape[0], np.prod(weight_shape[1:]))
#                 q = u if u.shape == flat_shape else v
#                 q = q.reshape(weight_shape)
#                 layer.weight.data.copy_(torch.Tensor(q))
#
#         elif isinstance(layer, nn.BatchNorm2d):
#             layer.weight.data.fill_(1)
#             layer.bias.data.zero_()
#         elif isinstance(layer, nn.Linear):
#             layer.bias.data.zero_()
#
# def conv3x3(in_planes, out_planes, stride=1):
#     """3x3 convolution with padding"""
#     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
#                      padding=1, bias=False)
#
# class BasicBlock(nn.Module):
#     expansion = 1
#
#     def __init__(self, inplanes, planes, stride=1, downsample=None):
#         super(BasicBlock, self).__init__()
#         self.conv1 = conv3x3(inplanes, planes, stride)
#         self.bn1 = nn.BatchNorm2d(planes)
#         self.relu = nn.ReLU(inplace=True)
#         self.conv2 = conv3x3(planes, planes)
#         self.bn2 = nn.BatchNorm2d(planes)
#         self.downsample = downsample
#         self.stride = stride
#
#     def forward(self, x):
#         residual = x
#
#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.relu(out)
#
#         out = self.conv2(out)
#         out = self.bn2(out)
#
#         if self.downsample is not None:
#             residual = self.downsample(x)
#
#         out += residual
#         out = self.relu(out)
#
#         return out
#
# class Bottleneck(nn.Module):
#     expansion = 4
#
#     def __init__(self, inplanes, planes, stride=1, downsample=None):
#         super(Bottleneck, self).__init__()
#         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
#         self.bn1 = nn.BatchNorm2d(planes)
#         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
#                                padding=1, bias=False)
#         self.bn2 = nn.BatchNorm2d(planes)
#         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
#         self.bn3 = nn.BatchNorm2d(planes * 4)
#         self.relu = nn.ReLU(inplace=True)
#         self.downsample = downsample
#         self.stride = stride
#
#     def forward(self, x):
#         residual = x
#
#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.relu(out)
#
#         out = self.conv2(out)
#         out = self.bn2(out)
#         out = self.relu(out)
#
#         out = self.conv3(out)
#         out = self.bn3(out)
#
#         if self.downsample is not None:
#             residual = self.downsample(x)
#
#         out += residual
#         out = self.relu(out)
#
#         return out
#
#
# class ResNet(nn.Module):
#
#     def __init__(self, block, layers, num_classes=1000):
#         self.inplanes = 64
#         super(ResNet, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
#                                bias=False)
#         self.bn1 = nn.BatchNorm2d(64)
#         self.relu = nn.ReLU(inplace=True)
#         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
#         self.layer1 = self._make_layer(block, 64, layers[0])
#         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
#         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
#         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
#         self.avgpool = nn.AvgPool2d(7, stride=1)
#         self.fc = nn.Linear(512 * block.expansion, num_classes)
#
#         self.apply(initweights)
#
#     def _make_layer(self, block, planes, blocks, stride=1):
#         downsample = None
#         if stride != 1 or self.inplanes != planes * block.expansion:
#             downsample = nn.Sequential(
#                 nn.Conv2d(self.inplanes, planes * block.expansion,
#                           kernel_size=1, stride=stride, bias=False),
#                 nn.BatchNorm2d(planes * block.expansion),
#             )
#
#         layers = []
#         layers.append(block(self.inplanes, planes, stride, downsample))
#         self.inplanes = planes * block.expansion
#         for i in range(1, blocks):
#             layers.append(block(self.inplanes, planes))
#
#         return nn.Sequential(*layers)
#
#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.bn1(x)
#         x = self.relu(x)
#         x = self.maxpool(x)
#
#         x = self.layer1(x)
#         x = self.layer2(x)
#         x = self.layer3(x)
#         x = self.layer4(x)
#
#         x = self.avgpool(x)
#         x = x.view(x.size(0), -1)
#         x = self.fc(x)
#
#         return x
#
# class ResNetImageNet(nn.Module):
#
#     def __init__(self, opt, num_classes=1000):
#         super(ResNetImageNet, self).__init__()
#         self.opt = opt
#         self.depth = opt.depth
#         self.num_classes = num_classes
#         if self.depth == 18:
#         	self.model = ResNet(BasicBlock, [2, 2, 2, 2], self.num_classes)
#         elif self.depth == 34:
#         	self.model = ResNet(BasicBlock, [3, 4, 6, 3], self.num_classes)
#         elif self.depth == 50:
#         	self.model = ResNet(Bottleneck, [3, 4, 6, 3], self.num_classes)
#         elif self.depth == 101:
#         	self.model = ResNet(Bottleneck, [3, 4, 23, 3], self.num_classes)
#         elif self.depth == 152:
#         	self.model = ResNet(Bottleneck, [3, 8, 36, 3], self.num_classes)
#
#
#     def forward(self, x):
#         x = self.model(x)
#         return x
#
#
#
# # def resnet18(pretrained=False, **kwargs):
# #     """Constructs a ResNet-18 model.
# #     Args:
# #         pretrained (bool): If True, returns a model pre-trained on ImageNet
# #     """
# #     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
# #     if pretrained:
# #         model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
# #     return model
コード例 #16
0
    else:
        dataset_val = dataset
    model = create_model(
        opt)  # create a model given opt.model and other options
    model.setup(
        opt)  # regular setup: load and print networks; create schedulers
    # create a website
    #web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch))  # define the website directory
    #webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
    # test with eval mode. This only affects layers like batchnorm and dropout.
    # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode.
    # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
    if opt.eval:
        model.eval()

    vggface2_model = resnet50(num_classes=8631, include_top=False)
    with open('pretrain_model/resnet50_ft_weight.pkl', 'rb') as f:
        obj = f.read()
    weights = {
        key: torch.from_numpy(arr)
        for key, arr in pickle.loads(obj, encoding='latin1').items()
    }
    vggface2_model.load_state_dict(weights)
    vggface2_model.to(opt.gpu_ids[0])
    vggface2_model = torch.nn.DataParallel(vggface2_model, opt.gpu_ids)
    vggface2_model.eval()

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[FID_DIMS]
    fid_model = InceptionV3([block_idx], normalize_input=False)
    fid_model.to(opt.gpu_ids[0])
    fid_mode = torch.nn.DataParallel(fid_model, opt.gpu_ids)
コード例 #17
0
def main():
    """Train a linear classifier on top of a pre-trained backbone.

    The function parses the command-line options, builds the train/val data
    loaders for the selected dataset, instantiates the backbone together with
    its linear classifier, loads the backbone checkpoint, optionally resumes
    the classifier/optimizer state, and then runs the train/validate loop
    while logging to tensorboard and saving checkpoints.

    Only ``classifier.parameters()`` are handed to the optimizer and the
    backbone is put in ``eval()`` mode before the loop starts.

    Raises:
        NotImplementedError: for an unsupported augmentation, dataset or model.
    """
    global best_acc1
    best_acc1 = 0

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    train_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val')

    # NOTE(review): the returned logger is never used below (messages go
    # through the ``logging`` module directly); the call is kept because it
    # presumably configures logging for ``args.save_folder`` -- confirm.
    logger = getLogger(args.save_folder)
    if args.dataset.startswith('imagenet') or args.dataset.startswith(
            'places'):
        image_size = 224
        crop_padding = 32
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        normalize = transforms.Normalize(mean=mean, std=std)
        if args.aug == 'NULL':
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size,
                                             scale=(args.crop, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        elif args.aug == 'CJ':
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size,
                                             scale=(args.crop, 1.)),
                transforms.RandomGrayscale(p=0.2),
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            # ``NotImplemented`` is a comparison sentinel, not an exception;
            # calling it raises TypeError. NotImplementedError is intended.
            raise NotImplementedError(
                'augmentation not supported: {}'.format(args.aug))

        val_transform = transforms.Compose([
            transforms.Resize(image_size + crop_padding),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])
        if args.dataset.startswith('imagenet'):
            train_dataset = datasets.ImageFolder(train_folder, train_transform)
            val_dataset = datasets.ImageFolder(
                val_folder,
                val_transform,
            )

        if args.dataset.startswith('places'):
            train_dataset = ImageList(
                '/data/trainvalsplit_places205/train_places205.csv',
                '/data/data/vision/torralba/deeplearning/images256',
                transform=train_transform,
                symbol_split=' ')
            val_dataset = ImageList(
                '/data/trainvalsplit_places205/val_places205.csv',
                '/data/data/vision/torralba/deeplearning/images256',
                transform=val_transform,
                symbol_split=' ')

        print(len(train_dataset))
        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.n_workers,
            pin_memory=False,
            sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.n_workers,
                                                 pin_memory=False)
    elif args.dataset.startswith('cifar'):
        train_loader, val_loader = cifar.get_linear_dataloader(args)
    elif args.dataset.startswith('svhn'):
        train_loader, val_loader = svhn.get_linear_dataloader(args)
    else:
        # Without this branch the loaders stay unbound and the failure only
        # shows up as a NameError once training starts.
        raise NotImplementedError('dataset not supported: {}'.format(
            args.dataset))

    # create model and optimizer
    if args.model == 'alexnet':
        if args.layer == 6:
            args.layer = 5
        model = AlexNet(128)
        model = nn.DataParallel(model)
        classifier = LinearClassifierAlexNet(args.layer, args.n_label, 'avg')
    elif args.model == 'alexnet_cifar':
        if args.layer == 6:
            args.layer = 5
        model = AlexNet_cifar(128)
        model = nn.DataParallel(model)
        classifier = LinearClassifierAlexNet(args.layer,
                                             args.n_label,
                                             'avg',
                                             cifar=True)
    elif args.model == 'resnet50':
        model = resnet50(non_linear_head=False)
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 1)
    elif args.model == 'resnet18':
        model = resnet18()
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer,
                                            args.n_label,
                                            'avg',
                                            1,
                                            bottleneck=False)
    elif args.model == 'resnet18_cifar':
        model = resnet18_cifar(256, non_linear_head=True, mlpbn=True)
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer,
                                            args.n_label,
                                            'avg',
                                            1,
                                            bottleneck=False)
    elif args.model == 'resnet50_cifar':
        model = resnet50_cifar()
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 1)
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 4)
    elif args.model == 'shufflenet':
        model = shufflenet_v2_x1_0(num_classes=128, non_linear_head=False)
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg',
                                            0.5)
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    print('==> loading pre-trained model')
    ckpt = torch.load(args.model_path)
    if not args.moco:
        model.load_state_dict(ckpt['state_dict'])
    else:
        try:
            state_dict = ckpt['state_dict']
            for k in list(state_dict.keys()):
                # retain only encoder_q up to before the embedding layer
                if k.startswith('module.encoder_q'
                                ) and not k.startswith('module.encoder_q.fc'):
                    # remove prefix
                    state_dict['module.' +
                               k[len("module.encoder_q."):]] = state_dict[k]
                # delete renamed or unused k
                del state_dict[k]
            model.load_state_dict(state_dict)
        except Exception as err:
            # Loading stays best-effort (the run continues with whatever
            # weights the model currently holds), but the failure is now
            # reported instead of being swallowed silently.
            print('==> WARNING: could not load MoCo weights from {}: {}'.format(
                args.model_path, err))
    print("==> loaded checkpoint '{}' (epoch {})".format(
        args.model_path, ckpt['epoch']))
    print('==> done')

    model = model.cuda()
    classifier = classifier.cuda()

    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

    # Only the classifier's parameters are optimized.
    if not args.adam:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(args.beta1, args.beta2),
                                     weight_decay=args.weight_decay,
                                     eps=1e-8)

    model.eval()
    cudnn.benchmark = True

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            classifier.load_state_dict(checkpoint['classifier'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # NOTE(review): ``.item()`` / ``.cuda()`` require the stored
            # value to be a tensor -- confirm what validate() returns.
            best_acc1 = checkpoint['best_acc1']
            print(best_acc1.item())
            best_acc1 = best_acc1.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            if 'opt' in checkpoint.keys():
                # resume optimization hyper-parameters
                print('=> resume hyper parameters')
                if 'bn' in vars(checkpoint['opt']):
                    print('using bn: ', checkpoint['opt'].bn)
                if 'adam' in vars(checkpoint['opt']):
                    print('using adam: ', checkpoint['opt'].adam)
                #args.learning_rate = checkpoint['opt'].learning_rate
                # args.lr_decay_epochs = checkpoint['opt'].lr_decay_epochs
                args.lr_decay_rate = checkpoint['opt'].lr_decay_rate
                args.momentum = checkpoint['opt'].momentum
                args.weight_decay = checkpoint['opt'].weight_decay
                args.beta1 = checkpoint['opt'].beta1
                args.beta2 = checkpoint['opt'].beta2
            # free the checkpoint before training starts
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    tblogger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    # ``best_acc`` is only used for the log line; ``best_acc1`` (module
    # global, possibly restored from a checkpoint) decides when to save.
    best_acc = 0.0
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_acc5, train_loss = train(epoch, train_loader, model,
                                                  classifier, criterion,
                                                  optimizer, args)
        time2 = time.time()
        logging.info('train epoch {}, total time {:.2f}'.format(
            epoch, time2 - time1))

        logging.info(
            'Epoch: {}, lr:{} , train_loss: {:.4f}, train_acc: {:.4f}/{:.4f}'.
            format(epoch, optimizer.param_groups[0]['lr'], train_loss,
                   train_acc, train_acc5))

        tblogger.log_value('train_acc', train_acc, epoch)
        tblogger.log_value('train_acc5', train_acc5, epoch)
        tblogger.log_value('train_loss', train_loss, epoch)
        tblogger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                           epoch)

        test_acc, test_acc5, test_loss = validate(val_loader, model,
                                                  classifier, criterion, args)

        if test_acc >= best_acc:
            best_acc = test_acc

        logging.info(
            colorful(
                'Epoch: {}, val_loss: {:.4f}, val_acc: {:.4f}/{:.4f}, best_acc: {:.4f}'
                .format(epoch, test_loss, test_acc, test_acc5, best_acc)))
        tblogger.log_value('test_acc', test_acc, epoch)
        tblogger.log_value('test_acc5', test_acc5, epoch)
        tblogger.log_value('test_loss', test_loss, epoch)

        # save the best model
        if test_acc > best_acc1:
            best_acc1 = test_acc
            state = {
                'opt': args,
                'epoch': epoch,
                'classifier': classifier.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }
            save_name = '{}_layer{}.pth'.format(args.model, args.layer)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving best model!')
            torch.save(state, save_name)

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            # NOTE(review): the periodic checkpoint stores this epoch's
            # ``test_acc`` under 'best_acc1', not the running best -- confirm
            # this is intended before resuming from these files.
            state = {
                'opt': args,
                'epoch': epoch,
                'classifier': classifier.state_dict(),
                'best_acc1': test_acc,
                'optimizer': optimizer.state_dict(),
            }
            save_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving regular model!')
            torch.save(state, save_name)
コード例 #18
0
def generate_model(opt):
    """Build the network selected by ``opt.model`` and prepare it for fine-tuning.

    On the CUDA path (``opt.no_cuda`` is false) the function optionally loads
    ``opt.pretrain_path``, adapts the first convolution / input modality,
    replaces the final classification layer with one sized for
    ``opt.n_finetune_classes`` (except for ``'mstcn'``), wraps the model in
    ``nn.DataParallel`` and moves it to the GPU.

    Args:
        opt: parsed options object; the attributes read are visible in the
            body (``model``, ``model_depth``, ``n_classes``, ``modality``,
            ``pretrain_path``, ``pretrain_dataset``, ...).

    Returns:
        ``(model, parameters)``. On the CUDA path ``parameters`` comes from
        ``get_fine_tuning_parameters(model, opt.ft_begin_index)`` (or
        ``model.trainable_parameters()`` for ``'mstcn'``); otherwise it is
        ``model.parameters()``.
    """
    assert opt.model in [
        'resnet', 'resnetl', 'resnext', 'c3d', 'mobilenetv2', 'shufflenetv2',
        "mstcn"
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 50]

        from models.resnet import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnet.resnet10(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
        elif opt.model_depth == 50:
            model = resnet.resnet50(num_classes=opt.n_classes,
                                    shortcut_type=opt.resnet_shortcut,
                                    sample_size=opt.sample_size,
                                    sample_duration=opt.sample_duration)
    elif opt.model == 'resnetl':
        assert opt.model_depth in [10, 18]

        from models.resnetl import get_fine_tuning_parameters

        if opt.model_depth == 10:
            model = resnetl.resnetl10(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
        elif opt.model_depth == 18:
            # NOTE(review): depth 18 also builds ``resnetl10`` -- looks like
            # a copy/paste slip; confirm whether an 18-layer factory exists.
            model = resnetl.resnetl10(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)

    elif opt.model == 'resnext':
        assert opt.model_depth in [101]

        from models.resnext import get_fine_tuning_parameters

        if opt.model_depth == 101:
            model = resnext.resnet101(num_classes=opt.n_classes,
                                      shortcut_type=opt.resnet_shortcut,
                                      cardinality=opt.resnext_cardinality,
                                      sample_size=opt.sample_size,
                                      sample_duration=opt.sample_duration)
    elif opt.model == 'c3d':
        assert opt.model_depth in [10]

        from models.c3d import get_fine_tuning_parameters
        if opt.model_depth == 10:

            model = c3d.c3d_v1(sample_size=opt.sample_size,
                               sample_duration=opt.sample_duration,
                               num_classes=opt.n_classes)
    elif opt.model == 'mobilenetv2':

        from models.mobilenetv2 import get_fine_tuning_parameters
        model = mobilenetv2.mob_v2(num_classes=opt.n_classes,
                                   sample_size=opt.sample_size,
                                   width_mult=opt.width_mult)
    elif opt.model == 'shufflenetv2':

        from models.shufflenetv2 import get_fine_tuning_parameters
        model = shufflenetv2.shf_v2(num_classes=opt.n_classes,
                                    sample_size=opt.sample_size,
                                    width_mult=opt.width_mult)

    elif opt.model == "mstcn":
        model = MultiStageTemporalConvNet(
            embed_size=opt.embedding_dim,
            encoder=opt.tcn_encoder,
            n_classes=opt.n_classes,
            input_size=(opt.sample_size, opt.sample_size),
            input_channels=4 if opt.modality != "RGB" else 3,
            num_stages=opt.tcn_stages,
            causal_config=opt.tcn_causality,
            CTHW_layout=True,
            use_preprocessing=opt.use_preprocessing)

    if not opt.no_cuda:

        # Stays None when no checkpoint path is given, so the state-dict
        # loading further down can be skipped instead of hitting an unbound
        # ``pretrain`` name.
        pretrain = None
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']
            if opt.pretrain_dataset == 'jester':
                if opt.sample_duration < 32 and opt.model != 'c3d':
                    model = _modify_first_conv_layer(model, 3, 3)
                # Drop the pretrained classification head before loading.
                if opt.model in ['mobilenetv2', 'shufflenetv2']:
                    del pretrain['state_dict']['module.classifier.1.weight']
                    del pretrain['state_dict']['module.classifier.1.bias']
                else:
                    del pretrain['state_dict']['module.fc.weight']
                    del pretrain['state_dict']['module.fc.bias']
                model.load_state_dict(pretrain['state_dict'], strict=False)

        if opt.modality in ['RGB', 'flo'] and opt.model != 'c3d':
            print("[INFO]: RGB model is used for init model")
            if opt.dataset != 'jester' and not opt.no_first_lay:
                model = _modify_first_conv_layer(
                    model, 3, 3)  ##### Check models trained (3,7,7) or (7,7,7)
        elif opt.modality in ['Depth', 'seg']:
            print(
                "[INFO]: Converting the pretrained model to Depth init model")
            model = _construct_depth_model(model)
            # NOTE(review): message says "Flow" inside the Depth/seg branch.
            print("[INFO]: Done. Flow model ready.")
        elif opt.modality in ['RGB-D', 'RGB-flo', 'RGB-seg']:
            if opt.model != "mstcn":
                print(
                    "[INFO]: Converting the pretrained model to RGB+D init model"
                )
                model = _construct_rgbdepth_model(model)
                if opt.no_first_lay:
                    model = _modify_first_conv_layer(
                        model, 3,
                        4)  ##### Check models trained (3,7,7) or (7,7,7)
                print("[INFO]: Done. RGB-D model ready.")
        if pretrain is not None:
            if opt.pretrain_dataset == opt.dataset:
                model.load_state_dict(pretrain['state_dict'])
            elif opt.pretrain_dataset in ['egogesture', 'nv', 'denso']:
                del pretrain['state_dict']['module.fc.weight']
                del pretrain['state_dict']['module.fc.bias']
                model.load_state_dict(pretrain['state_dict'], strict=False)

        # Check first kernel size
        if opt.model != "mstcn":
            modules = list(model.modules())
            # index of the first nn.Conv3d in module traversal order
            first_conv_idx = list(
                filter(lambda x: isinstance(modules[x], nn.Conv3d),
                       list(range(len(modules)))))[0]

            conv_layer = modules[first_conv_idx]
            if conv_layer.kernel_size[0] > opt.sample_duration:
                print("[INFO]: RGB model is used for init model")
                model = _modify_first_conv_layer(model,
                                                 int(opt.sample_duration / 2),
                                                 1)

        # NOTE(review): ``model.module`` is accessed here although the
        # nn.DataParallel wrap only happens further down -- presumably the
        # factories / helpers already return a wrapped model; confirm.
        if opt.model == 'c3d':  # CHECK HERE
            model.module.fc = nn.Linear(model.module.fc[0].in_features,
                                        model.module.fc[0].out_features)
            model.module.fc = model.module.fc.cuda()
        elif opt.model in ['mobilenetv2', 'shufflenetv2']:
            model.module.classifier = nn.Sequential(
                nn.Dropout(0.9),
                nn.Linear(model.module.classifier[1].in_features,
                          opt.n_finetune_classes))
            model.module.classifier = model.module.classifier.cuda()
        elif opt.model != "mstcn":
            model.module.fc = nn.Linear(model.module.fc.in_features,
                                        opt.n_finetune_classes)
            model.module.fc = model.module.fc.cuda()

        if opt.model != "mstcn":
            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
        else:
            parameters = model.trainable_parameters()

        model = nn.DataParallel(model, device_ids=None)

        model = model.cuda()
        return model, parameters

    else:
        print('ERROR no cuda')

    return model, model.parameters()
コード例 #19
0
def get_network(args, use_gpu=True, num_train=0):
    """Return the network selected by ``args.net``.

    Args:
        args: options object; reads ``dataset``, ``ignoring`` and ``net``
            (plus ``softmax`` and ``isalpha`` when ``args.ignoring`` is set).
        use_gpu: when true, the network is moved to the GPU with ``.cuda()``.
        num_train: forwarded to ``resnet18_ign`` in the ``ignoring`` setup.

    Returns:
        The constructed network.

    Raises:
        SystemExit: after printing a message when ``args.net`` is not
            supported (in either the regular or the ``ignoring`` setup).
    """
    import importlib

    if args.dataset == 'cifar-10':
        num_classes = 10
    elif args.dataset == 'cifar-100':
        num_classes = 100
    else:
        num_classes = 0

    # Factories that are called without arguments: net name -> (module, attr).
    plain_factories = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
    }
    # Factories that receive ``num_classes``: net name -> (module, attr).
    sized_factories = {
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
    }

    def _load_factory(module_name, attr_name):
        # Import lazily so only the selected architecture's module is loaded,
        # and keep ImportError as the failure type for a missing factory.
        module = importlib.import_module(module_name)
        try:
            return getattr(module, attr_name)
        except AttributeError as err:
            raise ImportError('cannot import name {!r} from {!r}'.format(
                attr_name, module_name)) from err

    def _unsupported():
        print('the network name you have entered is not supported yet')
        sys.exit()

    if args.ignoring:
        if args.net == 'resnet18':
            resnet18_ign = _load_factory('models.resnet_ign', 'resnet18_ign')
            criterion = nn.CrossEntropyLoss(reduction='none')
            net = resnet18_ign(criterion, num_classes=num_classes, num_train=num_train,softmax=args.softmax,isalpha=args.isalpha)
        else:
            # Previously ``net`` stayed unbound here and the function failed
            # later with UnboundLocalError.
            _unsupported()
    elif args.net in sized_factories:
        net = _load_factory(*sized_factories[args.net])(num_classes=num_classes)
    elif args.net in plain_factories:
        net = _load_factory(*plain_factories[args.net])()
    else:
        _unsupported()

    if use_gpu:
        net = net.cuda()

    return net
コード例 #20
0
def generate_model(opt):
    """Instantiate the model described by ``opt.model_name`` / ``opt.model_depth``.

    ``opt.mode`` selects whether the final fully connected layer is kept
    (``'score'`` -> ``last_fc=True``) or not (``'feature'`` ->
    ``last_fc=False``). Unless ``opt.no_cuda`` is set, the model is moved to
    the GPU and wrapped in ``nn.DataParallel``.

    Args:
        opt: options object; reads ``mode``, ``model_name``, ``model_depth``,
            ``n_classes``, ``sample_size``, ``sample_duration``, ``no_cuda``
            and the family-specific settings used below.

    Returns:
        The constructed model.
    """
    assert opt.mode in ['score', 'feature']
    last_fc = opt.mode == 'score'

    family = opt.model_name
    assert family in ['resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet']

    # Depths accepted for each model family.
    depths_by_family = {
        'resnet': [10, 18, 34, 50, 101, 152, 200],
        'wideresnet': [50],
        'resnext': [50, 101, 152],
        'preresnet': [18, 34, 50, 101, 152, 200],
        'densenet': [121, 169, 201, 264],
    }
    depth = opt.model_depth
    assert depth in depths_by_family[family]

    # Pick the module that provides the factory; every family except
    # densenet names its factories ``resnet<depth>``.
    prefix = 'resnet'
    if family == 'resnet':
        package = resnet
    elif family == 'wideresnet':
        package = wide_resnet
    elif family == 'resnext':
        package = resnext
    elif family == 'preresnet':
        package = pre_act_resnet
    else:
        package = densenet
        prefix = 'densenet'
    factory = getattr(package, '%s%d' % (prefix, depth))

    # Assemble the keyword arguments, reading only the options that the
    # selected family actually uses.
    build_args = {'num_classes': opt.n_classes}
    if family != 'densenet':
        build_args['shortcut_type'] = opt.resnet_shortcut
    if family == 'wideresnet':
        build_args['k'] = opt.wide_resnet_k
    elif family == 'resnext':
        build_args['cardinality'] = opt.resnext_cardinality
    build_args['sample_size'] = opt.sample_size
    build_args['sample_duration'] = opt.sample_duration
    build_args['last_fc'] = last_fc
    if family == 'resnext' and depth == 101:
        # Only the 101-layer ResNeXt factory receives these two switches.
        build_args['spatio_temporal'] = opt.spatio_temporal
        build_args['temporal_only'] = opt.temporal_only

    model = factory(**build_args)

    if opt.no_cuda:
        return model

    return nn.DataParallel(model.cuda(), device_ids=None)
コード例 #21
0
def _make_fc_head():
    """Return the 40-way classifier head for the 2048-feature backbones.

    Reads the module-level ``args``: a plain linear layer when
    ``args.dropout`` is falsy, otherwise dropout with ``p=args.dropout_p``
    followed by the same linear layer.
    """
    if not args.dropout:
        return nn.Linear(2048, 40)
    return nn.Sequential(
        nn.Dropout(p=args.dropout_p),  # Dropout
        nn.Linear(2048, 40))


def _load_local_resnet101(factory, model_pth_path):
    """Build ``factory(pretrained=False)`` and load weights from a local file.

    The checkpoint's ``'state_dict'`` entries are stored under
    ``'module.<name>'`` (presumably saved from an ``nn.DataParallel``
    wrapper -- confirm); each one is renamed to ``'<name>'`` before loading.
    The ``fc`` layer is replaced with a fresh 40-way linear layer afterwards,
    so the loaded classifier weights are discarded.

    Raises:
        KeyError: if the checkpoint lacks ``'state_dict'`` or a
            ``'module.<name>'`` entry for one of the model's own keys.
    """
    model = factory(pretrained=False)
    checkpoint_state_dict = torch.load(model_pth_path,
                                       map_location='cpu')['state_dict']
    for layer_name in model.state_dict():
        checkpoint_state_dict[layer_name] = checkpoint_state_dict.pop(
            'module.' + layer_name)
    model.load_state_dict(checkpoint_state_dict)
    model.fc = nn.Linear(2048, 40)
    return model


def create_model(model_name, pretrained):
    """Create a backbone with its classifier replaced by a 40-way head.

    Args:
        model_name: one of ``'resnet50'``, ``'resnext101_32x8d_wsl'``,
            ``'resnext101_32x16d_wsl'``, ``'resnext101_32x32d_wsl'``,
            ``'densenet121'``, ``'densenet169'``, ``'densenet201'``,
            ``'efficientnet_b7'``, ``'se_resnet'`` or ``'cbam_resnet'``.
        pretrained: forwarded to the backbone constructor. Ignored by
            ``'se_resnet'`` and ``'cbam_resnet'``, which always load weights
            from ``../model_pth/``.

    Returns:
        The constructed model, or ``None`` when ``model_name`` is not
        recognised (or when ``'efficientnet_b7'`` is requested with a
        ``pretrained`` value equal to neither ``True`` nor ``False``).
    """
    model = None
    if model_name == 'resnet50':
        model = resnet50(pretrained=pretrained, progress=True)
        model.fc = _make_fc_head()
    elif model_name == 'resnext101_32x8d_wsl':
        model = resnext101_32x8d_wsl(pretrained=pretrained, progress=True)
        model.fc = _make_fc_head()
    elif model_name == 'resnext101_32x16d_wsl':
        model = resnext101_32x16d_wsl(pretrained=pretrained, progress=True)
        model.fc = _make_fc_head()
    elif model_name == 'resnext101_32x32d_wsl':
        model = resnext101_32x32d_wsl(pretrained=pretrained, progress=True)
        model.fc = _make_fc_head()
    elif model_name == 'densenet121':
        model = densenet121(pretrained=pretrained, progress=True)
        model.classifier = nn.Linear(1024, 40)
    elif model_name == 'densenet169':
        model = densenet169(pretrained=pretrained, progress=True)
        model.classifier = nn.Linear(1664, 40)
    elif model_name == 'densenet201':
        model = densenet201(pretrained=pretrained, progress=True)
        model.classifier = nn.Linear(1920, 40)
    elif model_name == 'efficientnet_b7':
        # Equality (not truthiness) is kept on purpose so that values such as
        # None still yield no model, exactly as before.
        if pretrained == True:
            model = efficientnet.from_pretrained('efficientnet-b7',
                                                 num_classes=40)
        elif pretrained == False:
            model = efficientnet.from_name('efficientnet-b7',
                                           override_params={'num_classes': 40})
    elif model_name == 'se_resnet':
        model = _load_local_resnet101(se_resnet101,
                                      '../model_pth/se_resnet101.pth.tar')
    elif model_name == 'cbam_resnet':
        model = _load_local_resnet101(cbam_resnet101,
                                      '../model_pth/cbam_resnet101.pth.tar')
    # Single exit point: previously the 'cbam_resnet' branch built the model
    # but never returned it, so callers received None.
    return model
コード例 #22
0
    args.data + '/train_images', transform=data_transforms),
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=4)
# Validation loader: images under <args.data>/val_images, half the training
# batch size, fixed order.
val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
    args.data + '/val_images', transform=val_transforms),
                                         batch_size=args.batch_size // 2,
                                         shuffle=False,
                                         num_workers=4)

### Neural Network and Optimizer
# We define neural net in model.py so that it can be reused by the evaluate.py script
from model import Net
#model = Net(args.dp)
# The two checks below are independent substring tests on args.net: if both
# match, the resnet18 assignment wins; if neither matches, `model` is never
# bound and the first use below raises NameError.
if 'resnet50' in args.net:
    model = resnet.resnet50(args.pretrained, dp=args.dp)
    #model = resnet.resnet50(True, dp = args.dp)
if 'resnet18' in args.net:
    model = resnet.resnet18(args.pretrained, dp=args.dp)

# Optionally restore weights; args.load is passed straight to torch.load and
# the result is used as the state dict.
if args.load:
    model.load_state_dict(torch.load(args.load))
#model.load_state_dict(torch.load(['model_latest.pth'])['state_dict'])

# SGD over all model parameters, with a step learning-rate schedule whose
# step size is args.step.
optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)
scheduler = optim.lr_scheduler.StepLR(optimizer, args.step)
# The device is hard-coded to the first CUDA GPU.
device = torch.device('cuda:0')
model.to(device)
コード例 #23
0
def main():
    """Print the statistics stored in a checkpoint, optionally re-evaluating.

    The checkpoint at ``args.checkpoint_path`` must provide ``'model_name'``,
    ``'patch_size'`` and ``'flops'``. The remaining statistics
    (``'model_flops'``, ``'policy_flops'``, ``'fc_flops'``,
    ``'anytime_classification'``, ``'budgeted_batch_classification'``,
    ``'dynamic_threshold'``) are optional and default to ``None``.

    Behaviour by ``args.eval_mode``:
      * ``0`` (or less): only print the stored statistics.
      * ``> 0``: rebuild the networks, load their weights, regenerate logits
        on the validation set and recompute the budgeted-batch curve using
        the stored ``'dynamic_threshold'``.
      * ``2``: additionally recompute the thresholds from logits generated on
        a 200000-sample subset of the training set.

    Raises:
        KeyError / TypeError: a required checkpoint entry is missing or
            malformed (a hint is printed first).
        KeyError: thresholds are needed (``eval_mode`` > 0 and != 2) but the
            checkpoint has no ``'dynamic_threshold'``.
        ValueError: ``model_name`` matches none of the supported backbones.
    """
    # load pretrained model
    checkpoint = torch.load(args.checkpoint_path)

    # Required entries. Previously a bare ``except`` printed the hint and
    # carried on, which only deferred the failure to a confusing NameError.
    try:
        model_arch = checkpoint['model_name']
        patch_size = checkpoint['patch_size']
        # NOTE(review): prime_size is also read from 'patch_size' -- confirm
        # this is intended rather than a separate 'prime_size' entry.
        prime_size = checkpoint['patch_size']
        flops = checkpoint['flops']
        maximum_length = len(flops)
    except (KeyError, TypeError):
        print(
            'Error: \n'
            'Please provide essential information'
            'for customized models (as we have done '
            'in pre-trained models)!\n'
            'At least the following information should be Given: \n'
            '--model_name: name of the backbone CNNs (e.g., resnet50, densenet121)\n'
            '--patch_size: size of image patches (i.e., H\' or W\' in the paper)\n'
            '--flops: a list containing the Multiply-Adds corresponding to each '
            'length of the input sequence during inference')
        raise

    # Optional entries: only reported below, or overwritten by evaluation.
    model_flops = checkpoint.get('model_flops')
    policy_flops = checkpoint.get('policy_flops')
    fc_flops = checkpoint.get('fc_flops')
    anytime_classification = checkpoint.get('anytime_classification')
    budgeted_batch_classification = checkpoint.get(
        'budgeted_batch_classification')
    dynamic_threshold = checkpoint.get('dynamic_threshold')

    model_configuration = model_configurations[model_arch]

    if args.eval_mode > 0:
        if args.eval_mode != 2 and dynamic_threshold is None:
            raise KeyError(
                "checkpoint has no 'dynamic_threshold'; "
                "run with eval_mode 2 to recompute it")

        # create model
        if 'resnet' in model_arch:
            model = resnet.resnet50(pretrained=False)
            model_prime = resnet.resnet50(pretrained=False)

        elif 'densenet' in model_arch:
            # getattr instead of eval(): same lookup, but a checkpoint-supplied
            # string is never executed as code.
            densenet_factory = getattr(densenet, model_arch)
            model = densenet_factory(pretrained=False)
            model_prime = densenet_factory(pretrained=False)

        elif 'efficientnet' in model_arch:
            model = create_model(model_arch,
                                 pretrained=False,
                                 num_classes=1000,
                                 drop_rate=0.3,
                                 drop_connect_rate=0.2)
            model_prime = create_model(model_arch,
                                       pretrained=False,
                                       num_classes=1000,
                                       drop_rate=0.3,
                                       drop_connect_rate=0.2)

        elif 'mobilenetv3' in model_arch:
            model = create_model(model_arch,
                                 pretrained=False,
                                 num_classes=1000,
                                 drop_rate=0.2,
                                 drop_connect_rate=0.2)
            model_prime = create_model(model_arch,
                                       pretrained=False,
                                       num_classes=1000,
                                       drop_rate=0.2,
                                       drop_connect_rate=0.2)

        elif 'regnet' in model_arch:
            import pycls.core.model_builder as model_builder
            from pycls.core.config import cfg
            cfg.merge_from_file(model_configuration['cfg_file'])
            cfg.freeze()

            model = model_builder.build_model()
            model_prime = model_builder.build_model()

        else:
            # Fail here rather than with an unbound-variable error after the
            # datasets have already been built.
            raise ValueError(
                'Unsupported model_name {!r}'.format(model_arch))

        traindir = args.data_url + 'train/'
        valdir = args.data_url + 'val/'

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_set = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(
                    model_configuration['image_size'],
                    interpolation=model_configuration['dataset_interpolation']
                ),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(), normalize
            ]))
        # Thresholds are fitted on a random subset: the last 200000 indices of
        # a random permutation of the training set.
        train_set_index = torch.randperm(len(train_set))
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=256,
            num_workers=32,
            pin_memory=False,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                train_set_index[-200000:]))

        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(
                    int(model_configuration['image_size'] /
                        model_configuration['crop_pct']),
                    interpolation=model_configuration['dataset_interpolation']
                ),
                transforms.CenterCrop(model_configuration['image_size']),
                transforms.ToTensor(), normalize
            ])),
                                                 batch_size=256,
                                                 shuffle=False,
                                                 num_workers=16,
                                                 pin_memory=False)

        state_dim = model_configuration['feature_map_channels'] * math.ceil(
            patch_size / 32) * math.ceil(patch_size / 32)

        memory = Memory()
        policy = ActorCritic(model_configuration['feature_map_channels'],
                             state_dim,
                             model_configuration['policy_hidden_dim'],
                             model_configuration['policy_conv'])
        fc = Full_layer(model_configuration['feature_num'],
                        model_configuration['fc_hidden_dim'],
                        model_configuration['fc_rnn'])

        # The backbones are wrapped in DataParallel before their state dicts
        # are loaded; policy and fc are loaded unwrapped.
        model = nn.DataParallel(model.cuda())
        model_prime = nn.DataParallel(model_prime.cuda())
        policy = policy.cuda()
        fc = fc.cuda()

        model.load_state_dict(checkpoint['model_state_dict'])
        model_prime.load_state_dict(checkpoint['model_prime_state_dict'])
        fc.load_state_dict(checkpoint['fc'])
        policy.load_state_dict(checkpoint['policy'])

        budgeted_batch_flops_list = []
        budgeted_batch_acc_list = []

        print('generate logits on test samples...')
        test_logits, test_targets, anytime_classification = generate_logits(
            model_prime, model, fc, memory, policy, val_loader, maximum_length,
            prime_size, patch_size, model_arch)

        if args.eval_mode == 2:
            print('generate logits on training samples...')
            # One row of thresholds per budget setting p = 1..39 below.
            dynamic_threshold = torch.zeros([39, maximum_length])
            train_logits, train_targets, _ = generate_logits(
                model_prime, model, fc, memory, policy, train_loader,
                maximum_length, prime_size, patch_size, model_arch)

        for p in range(1, 40):

            print('inference: {}/40'.format(p))

            # probs[k] is proportional to (p / 20) ** (k + 1), normalised to
            # sum to one over the maximum_length steps.
            _p = torch.FloatTensor(1).fill_(p * 1.0 / 20)
            # torch.arange with an exclusive end replaces the deprecated,
            # end-inclusive torch.range(1, maximum_length).
            steps = torch.arange(1, maximum_length + 1,
                                 dtype=torch.get_default_dtype())
            probs = torch.exp(torch.log(_p) * steps)
            probs /= probs.sum()

            if args.eval_mode == 2:
                dynamic_threshold[p - 1] = dynamic_find_threshold(
                    train_logits, train_targets, probs)

            acc_step, flops_step = dynamic_evaluate(test_logits, test_targets,
                                                    flops,
                                                    dynamic_threshold[p - 1])

            budgeted_batch_acc_list.append(acc_step)
            budgeted_batch_flops_list.append(flops_step)

        budgeted_batch_classification = [
            budgeted_batch_flops_list, budgeted_batch_acc_list
        ]

    print('model_arch :', model_arch)
    print('patch_size :', patch_size)
    print('flops :', flops)
    print('model_flops :', model_flops)
    print('policy_flops :', policy_flops)
    print('fc_flops :', fc_flops)
    print('anytime_classification :', anytime_classification)
    print('budgeted_batch_classification :', budgeted_batch_classification)
コード例 #24
0
def main():
    """Train and validate the three-branch model once per fold.

    Parses the command line into the module-level ``args``, then for each
    ``fold`` in ``range(args.folds)``:

      1. builds a fresh ResNet-50 base model, attention branch and region
         branch (each wrapped in ``DataParallel`` on the global ``device``);
      2. creates one SGD optimizer with a fixed ``1e-5`` learning rate for the
         base model and ``args.lr`` for both branches;
      3. optionally restores all three models and the optimizer from
         ``args.resume``;
      4. builds the train/validation loaders for the fold, applying the
         sampling or re-weighting selected by ``args.train_rule``;
      5. trains for ``args.start_epoch .. args.epochs - 1``, validating and
         checkpointing after every epoch.

    Returns early, after a warning, when ``args.train_rule`` or
    ``args.loss_type`` is not one of the supported values.
    """
    #Print args
    global args, best_prec1
    args = parser.parse_args()
    print(
        '\n\t\t\t\t Aum Sri Sai Ram\nFER on CKPLUS using Local and global Attention along with region branch (non-overlapping patches)\n\n'
    )
    print(args)
    print('\nimg_dir: ', args.root_path)
    print('\ntrain rule: ', args.train_rule, ' and loss type: ',
          args.loss_type, '\n')

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.4,
                               contrast=0.3,
                               saturation=0.25,
                               hue=0.05),
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    valid_transform = transforms.Compose([
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # prepare model

    for fold in range(0, args.folds):
        print(
            '\n***************************************************************************************\n\nFold is ',
            fold)
        basemodel = resnet50(pretrained=False)
        attention_model = AttentionBranch(
            inputdim=512,
            num_regions=args.num_attentive_regions,
            num_classes=args.num_classes)
        region_model = RegionBranch(inputdim=1024,
                                    num_regions=args.num_regions,
                                    num_classes=args.num_classes)

        basemodel = torch.nn.DataParallel(basemodel).to(device)
        attention_model = torch.nn.DataParallel(attention_model).to(device)
        region_model = torch.nn.DataParallel(region_model).to(device)

        print('\nNumber of parameters:')
        print(
            'Base Model: {}, Attention Branch:{}, Region Branch:{} and Total: {}'
            .format(
                count_parameters(basemodel), count_parameters(attention_model),
                count_parameters(region_model),
                count_parameters(basemodel) +
                count_parameters(attention_model) +
                count_parameters(region_model)))

        best_prec1 = 0

        # Three parameter groups, in this order: base model (fixed small
        # learning rate), attention branch, region branch.
        optimizer = torch.optim.SGD([{
            "params": basemodel.parameters(),
            "lr": 0.00001,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay
        }, {
            "params": attention_model.parameters(),
            "lr": args.lr,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay
        }, {
            "params": region_model.parameters(),
            "lr": args.lr,
            "momentum": args.momentum,
            "weight_decay": args.weight_decay
        }])
        '''      
        if args.pretrained:
           util.load_state_dict(basemodel,'pretrainedmodels/vgg_msceleb_resnet50_ft_weight.pkl')   
        '''
        # NOTE(review): the same args.resume checkpoint is reloaded at the
        # start of every fold -- confirm this is intended.
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                basemodel.load_state_dict(checkpoint['base_state_dict'])
                attention_model.load_state_dict(
                    checkpoint['attention_state_dict'])
                region_model.load_state_dict(checkpoint['region_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        # Derive this fold's list files from the previous fold's names.
        # NOTE(review): str.replace rewrites EVERY occurrence of the previous
        # fold number in the path, not only the fold suffix -- verify that the
        # list paths contain no other matching digits.
        if fold > 0:
            args.valid_list = args.valid_list.replace(str(fold - 1), str(fold))
            args.train_list = args.train_list.replace(str(fold - 1), str(fold))
        print(args.train_list, args.valid_list)
        val_data = ImageList(root=args.root_path,
                             fileList=args.valid_list,
                             transform=valid_transform)

        val_loader = torch.utils.data.DataLoader(val_data,
                                                 args.batch_size,
                                                 shuffle=False,
                                                 num_workers=8)

        train_dataset = ImageList(root=args.root_path,
                                  fileList=args.train_list,
                                  transform=train_transform)
        cls_num_list = train_dataset.get_cls_num_list()

        if args.train_rule == 'None':
            train_sampler = None
            per_cls_weights = None
        elif args.train_rule == 'Resample':
            train_sampler = ImbalancedDatasetSampler(train_dataset)
            per_cls_weights = None
        elif args.train_rule == 'Reweight':
            train_sampler = None
            # Per-class weight (1 - beta) / (1 - beta ** n_c), rescaled so
            # the weights sum to the number of classes.
            beta = 0.9999
            effective_num = 1.0 - np.power(beta, cls_num_list)
            per_cls_weights = (1.0 - beta) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
        else:
            # Previously an unknown rule left train_sampler/per_cls_weights
            # unbound and crashed with UnboundLocalError further down; handle
            # it the same way as an unknown loss type.
            warnings.warn('Train rule is not listed')
            return

        if args.loss_type == 'CE':
            criterion = nn.CrossEntropyLoss(weight=per_cls_weights).to(device)
        elif args.loss_type == 'Focal':
            criterion = FocalLoss(weight=per_cls_weights, gamma=2).to(device)
        else:
            warnings.warn('Loss type is not listed')
            return

        # shuffle and sampler are mutually exclusive in DataLoader, so only
        # shuffle when no sampler is in use.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler)

        print('length of train Database: ' + str(len(train_dataset)))
        print('length of test Database: ' + str(len(val_loader.dataset)))

        print('\nTraining starting: ', fold)

        for epoch in range(args.start_epoch, args.epochs):
            train(train_loader, basemodel, attention_model, region_model,
                  criterion, optimizer, epoch)
            #adjust_learning_rate(optimizer, epoch)
            prec1 = validate(val_loader, basemodel, attention_model,
                             region_model, criterion, epoch)
            print("Epoch: {}   Test Acc: {}".format(epoch, prec1))
            # remember best prec@1 and save checkpoint
            # Read the scalar once; .item() already copies it to the host, so
            # no device transfer is needed first.
            prec1_value = prec1.item()
            is_best = prec1_value > best_prec1

            best_prec1 = max(prec1_value, best_prec1)
            if is_best:
                print(
                    '\n*********************************\nBest so far is :\t',
                    best_prec1)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'base_state_dict': basemodel.state_dict(),
                    'attention_state_dict': attention_model.state_dict(),
                    'region_state_dict': region_model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, 'checkpoint.pth.tar', fold)
コード例 #25
0
ファイル: utils.py プロジェクト: zhuangwang93/mergeComp
def get_network(args):
    """Instantiate the network selected by ``args.model``.

    The architecture is looked up in a registry that maps each supported
    name to the module defining it and the factory function to call. Only
    the selected module is imported. An unsupported name prints a message
    and terminates the process via ``sys.exit()``.

    Args:
        args: object with a ``model`` attribute holding the network name.

    Returns:
        The model returned by the selected factory, called with no arguments.
    """
    # network name -> (defining module, factory function in that module)
    registry = {
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth',
                              'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth',
                              'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth',
                              'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth',
                               'stochastic_depth_resnet101'),
    }

    entry = registry.get(args.model)
    if entry is None:
        print('the network name you have entered is not supported yet')
        sys.exit()

    module_path, factory_name = entry
    # A non-empty fromlist makes __import__ return the leaf module itself,
    # mirroring ``from <module_path> import <factory_name>``.
    module = __import__(module_path, fromlist=[factory_name])
    model = getattr(module, factory_name)()

    return model
コード例 #26
0
    print('----------------------------------')

    # Dump every option as "name : value".
    unparsed = vars(ops)  # parse_args() returns a Namespace; the built-in vars() turns it into a dict
    for key in unparsed.keys():
        print('{} : {}'.format(key, unparsed[key]))

    os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUS
    #---------------------------------------------------------------- build the landmarks model
    # Each constructor receives the class count and ops.input_shape[2] as the
    # image size (presumably the input height of an NCHW shape -- confirm).
    if ops.network == 'resnet_18':
        model_ = resnet18(num_classes=ops.num_classes,
                          img_size=ops.input_shape[2])
    elif ops.network == 'resnet_34':
        model_ = resnet34(num_classes=ops.num_classes,
                          img_size=ops.input_shape[2])
    elif ops.network == 'resnet_50':
        model_ = resnet50(num_classes=ops.num_classes,
                          img_size=ops.input_shape[2])
    elif ops.network == 'resnet_101':
        model_ = resnet101(num_classes=ops.num_classes,
                           img_size=ops.input_shape[2])
    elif ops.network == 'resnet_152':
        model_ = resnet152(num_classes=ops.num_classes,
                           img_size=ops.input_shape[2])
    elif ops.network == 'mobilenetv2':
        model_ = MobileNetV2(n_class=ops.num_classes,
                             input_size=ops.input_shape[2])
    else:
        # NOTE(review): this branch only prints; model_ stays unbound, so the
        # profile() call below then fails with a NameError.
        print('error no the struct model : {}'.format(ops.network))

    # Profile the model on one random tensor of shape ops.input_shape and
    # report the two values returned by profile().
    dummy_input = torch.randn(ops.input_shape)
    flops, params = profile(model_, inputs=(dummy_input, ))
    print('flops : {} , params : {}'.format(flops, params))
コード例 #27
0
def get_model_dics(device, model_list=None):
    """Build a ``{name: model}`` dict for every recognised name in a list.

    Two kinds of entries exist. Most names construct a 110-class network
    and restore its weights through ``load_model(model, path, device)``.
    The three ``old_*`` names first load a source file as module
    ``'MainModel'`` and then ``torch.load`` the saved model object, moving
    it to ``device``.

    Names that match no known entry are silently skipped. Models are
    created in the order the names appear in ``model_list``.

    Args:
        device: forwarded to ``load_model`` / ``.to``.
        model_list: names to load; ``None`` selects the built-in default.

    Returns:
        dict mapping each recognised name to its loaded model.
    """
    if model_list is None:
        # NOTE(review): 'inception_v4' matches no entry below (the key is
        # 'incept_v4'), so it is skipped -- confirm whether that is intended.
        model_list = ['densenet121', 'densenet161', 'resnet50', 'resnet152',
                      'incept_v1', 'incept_v3', 'inception_v4', 'incept_resnet_v2',
                      'incept_v4_adv2', 'incept_resnet_v2_adv2',
                      'black_densenet161','black_resnet50','black_incept_v3',
                      'old_vgg','old_res','old_incept']

    # name -> (zero-argument constructor, weight file). Lambdas keep the
    # constructor names unresolved until an entry is actually requested.
    weighted = {
        'densenet121': (lambda: densenet121(num_classes=110),
                        "../pre_weights/ep_38_densenet121_val_acc_0.6527.pth"),
        'densenet161': (lambda: densenet161(num_classes=110),
                        "../pre_weights/ep_30_densenet161_val_acc_0.6990.pth"),
        'resnet50': (lambda: resnet50(num_classes=110),
                     "../pre_weights/ep_41_resnet50_val_acc_0.6900.pth"),
        'incept_v3': (lambda: inception_v3(num_classes=110),
                      "../pre_weights/ep_36_inception_v3_val_acc_0.6668.pth"),
        'incept_v1': (lambda: googlenet(num_classes=110),
                      "../pre_weights/ep_33_googlenet_val_acc_0.7091.pth"),
        'incept_resnet_v2': (lambda: InceptionResNetV2(num_classes=110),
                             "../pre_weights/ep_17_InceptionResNetV2_ori_0.8320.pth"),
        'incept_v4': (lambda: InceptionV4(num_classes=110),
                      "../pre_weights/ep_17_InceptionV4_ori_0.8171.pth"),
        'incept_resnet_v2_adv': (lambda: InceptionResNetV2(num_classes=110),
                                 "../pre_weights/ep_22_InceptionResNetV2_val_acc_0.8214.pth"),
        'incept_v4_adv': (lambda: InceptionV4(num_classes=110),
                          "../pre_weights/ep_24_InceptionV4_val_acc_0.6765.pth"),
        'incept_resnet_v2_adv2': (lambda: InceptionResNetV2(num_classes=110),
                                  "../test_weights/ep_13_InceptionResNetV2_val_acc_0.8889.pth"),
        'incept_v4_adv2': (lambda: InceptionV4(num_classes=110),
                           "../test_weights/ep_50_InceptionV4_val_acc_0.8295.pth"),
        'resnet152': (lambda: resnet152(num_classes=110),
                      "../pre_weights/ep_14_resnet152_ori_0.6956.pth"),
        'resnet152_adv': (lambda: resnet152(num_classes=110),
                          "../pre_weights/ep_29_resnet152_adv_0.6939.pth"),
        'resnet152_adv2': (lambda: resnet152(num_classes=110),
                           "../pre_weights/ep_31_resnet152_adv2_0.6931.pth"),
        'black_resnet50': (lambda: resnet50(num_classes=110),
                           "../test_weights/ep_0_resnet50_val_acc_0.7063.pth"),
        'black_densenet161': (lambda: densenet161(num_classes=110),
                              "../test_weights/ep_4_densenet161_val_acc_0.6892.pth"),
        'black_incept_v3': (lambda: inception_v3(num_classes=110),
                            "../test_weights/ep_28_inception_v3_val_acc_0.6680.pth"),
    }

    # name -> (source file loaded as module 'MainModel', saved model file).
    legacy = {
        'old_res': ("./models_old/tf_to_pytorch_resnet_v1_50.py",
                    './models_old/tf_to_pytorch_resnet_v1_50.pth'),
        'old_vgg': ("./models_old/tf_to_pytorch_vgg16.py",
                    './models_old/tf_to_pytorch_vgg16.pth'),
        'old_incept': ("./models_old/tf_to_pytorch_inception_v1.py",
                       './models_old/tf_to_pytorch_inception_v1.pth'),
    }

    models = {}
    for name in model_list:
        if name in weighted:
            build, weight_path = weighted[name]
            models[name] = build()
            load_model(models[name], weight_path, device)
        elif name in legacy:
            source_path, saved_path = legacy[name]
            # The 'MainModel' module is loaded before torch.load runs
            # (presumably the saved object needs it to unpickle -- confirm).
            MainModel = imp.load_source('MainModel', source_path)
            models[name] = torch.load(saved_path).to(device)

    return models
コード例 #28
0
ファイル: test_new.py プロジェクト: tejas-gokhale/BIRD_Code
# Eleven triples of 0-255 channel values, indexed by class id.
color_dict = [[255, 255, 255], [0, 0, 255], [255, 255, 255], [0, 128, 128],
              [255, 255, 255], [255, 165, 0], [255, 255, 255], [128, 0, 128],
              [255, 0, 0], [255, 255, 255], [255, 255, 0]]
# Parallel table of eleven 0/1 triples.
# NOTE(review): how these relate to color_dict (channel order, thresholding)
# is not evident from this file -- confirm before relying on it.
color_3bit = [[0, 0, 0], [0, 1, 1], [0, 0, 0], [0, 1, 0], [0, 0, 0], [1, 1, 0],
              [0, 0, 0], [1, 0, 1], [0, 0, 1], [0, 0, 0], [1, 0, 0]]
color_list = []

# color_list holds the color_dict entries rescaled from 0-255 to 0-1 floats.
for i in range(11):
    color_list.append([item / 255. for item in color_dict[i]])

# LOAD MODEL FROM CHECKPOINT
# ---------------------------
transform = transforms.Compose(
    [transforms.Resize(size), transforms.ToTensor()])
data_set = EbayColor('./blocks_data/', feat_size, transform)
model = resnet50(False, '', 11)
# Replace the classifier with an 11-way linear layer over 512 input features.
model.fc = torch.nn.Linear(512, 11)

# Evaluation mode for batch normalization freeze
model.eval()
# Disable gradients for every parameter.
for p in model.parameters():
    p.requires_grad = False

# Load Back bone Module
state_dict = torch.load(
    './checkpoint/lr05_b4_alpha10_beta5/Color_pretrain_regression_96.pth'
)['state_dict']
# Start from the model's current state and overlay the checkpoint entries, so
# keys absent from the checkpoint keep their existing values.
new_params = model.state_dict()
new_params.update(state_dict)
model.load_state_dict(new_params)
model = model.cuda()
コード例 #29
0
ファイル: train.py プロジェクト: YangfeiLiu/Classify

for ds in dataset:
    data_path = os.path.join(args.root, ds)
    cls = [
        x for x in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, x))
    ]
    num_class = len(cls)
    models = {
        "vgg16": vgg.vgg16_bn(num_class),
        "vgg19": vgg.vgg19_bn(num_class),
        "densenet121": densenet.densenet121(num_class),
        "densenet161": densenet.densenet161(num_class),
        "resnet34": resnet.resnet34(num_class),
        "resnet50": resnet.resnet50(num_class),
        "resnet101": resnet.resnet101(num_class),
        "seresnet34": senet.seresnet34(num_class),
        "seresnet50": senet.seresnet50(num_class),
        "seresnet101": senet.seresnet101(num_class),
        "resnext34": resnext.resnext34(num_class),
        "resnext50": resnext.resnext50(num_class),
        "resnext101": resnext.resnext101(num_class),
        "shufflenet": shufflenet.shufflenet(num_class),
        "xception": xception.xception(num_class)
    }
    for net_name in models.keys():
        writer = SummaryWriter('./runs/%s_%s/' % (ds, net_name))
        model = models[net_name]
        logger.add('./log/%s_%s_{time}.log' % (ds, net_name), level="INFO")
        logger.info("net:%s\t dataset:%s\t num_class:%d" %
コード例 #30
0
def generate_model(opt):
    """Build the network selected by ``opt`` and return it with its parameters.

    The constructor is looked up by name on the family's module (for example
    ``resnet.resnet50`` for ``opt.model == 'resnet'`` and
    ``opt.model_depth == 50``). Unless ``opt.no_cuda`` is set, the model is
    moved to the GPU and wrapped in ``nn.DataParallel``. When
    ``opt.pretrain_path`` is set, the checkpoint's ``'state_dict'`` is loaded
    and the final linear layer is replaced by a fresh one with
    ``opt.n_finetune_classes`` outputs.

    Args:
        opt: Options object. Always read: ``model``, ``model_depth``,
            ``n_classes``, ``no_cuda``, ``pretrain_path``. Read depending on
            the family: ``resnet_shortcut``, ``sample_size``,
            ``sample_duration``, ``wide_resnet_k``, ``resnext_cardinality``.
            Read only when ``pretrain_path`` is set: ``n_finetune_classes``,
            ``ft_begin_index`` and, on the CPU path, ``arch``.

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` is the result of the
        family's ``get_fine_tuning_parameters(model, opt.ft_begin_index)``
        when a pretrained checkpoint was loaded, and ``model.parameters()``
        otherwise (and always for the 'senet' family, which has no such
        helper).

    Raises:
        AssertionError: If ``opt.model`` / ``opt.model_depth`` is not a
            supported combination, or (CPU path) if ``opt.arch`` differs from
            the checkpoint's ``'arch'`` entry.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet', 'senet'
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        build = getattr(wide_resnet, 'resnet%d' % opt.model_depth)
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            k=opt.wide_resnet_k,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnet%d' % opt.model_depth)
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        build = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        build = getattr(densenet, 'densenet%d' % opt.model_depth)
        model = build(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'senet':
        assert opt.model_depth in [50, 101, 152, 154, 5032, 10132]

        # Depth code -> constructor name on the senet module.
        # NOTE(review): 5032 maps to 'resnext50_32x4d' while 10132 maps to
        # 'se_resnext101_32x4d'; the missing 'se_' prefix may be a typo -
        # kept as in the original, confirm against the senet module.
        senet_builders = {
            50: 'se_resnet50',
            101: 'se_resnet101',
            152: 'se_resnet152',
            154: 'senet154',
            5032: 'resnext50_32x4d',
            10132: 'se_resnext101_32x4d',
        }
        build = getattr(senet, senet_builders[opt.model_depth])
        model = build(num_classes=opt.n_classes, pretrained=None)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            # Checkpoint entries whose key contains 'module.fc' are dropped,
            # so the model keeps its own values for that layer (presumably to
            # tolerate a checkpoint trained with a different class count).
            # NOTE(review): the densenet / senet heads ('module.classifier',
            # 'module.last_linear') are not filtered - confirm intended.
            pretrained_dict = {
                k: v for k, v in pretrain['state_dict'].items()
                if 'module.fc' not in k
            }
            model_dict = model.state_dict()
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features, opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            elif opt.model == 'senet':
                model.module.last_linear = nn.Linear(
                    model.module.last_linear.in_features, opt.n_finetune_classes)
                model.module.last_linear = model.module.last_linear.cuda()
                # No get_fine_tuning_parameters exists for this family.
                return model, model.parameters()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # map_location lets a checkpoint that was saved from GPU tensors
            # be read on a host without CUDA.
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            assert opt.arch == pretrain['arch']

            # The model is not wrapped in DataParallel here, so strip the
            # 'module.' prefix that a DataParallel-saved checkpoint carries;
            # keys without the prefix pass through unchanged.
            prefix = 'module.'
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in pretrain['state_dict'].items()
            }
            model.load_state_dict(state_dict)

            if opt.model == 'densenet':
                model.classifier = nn.Linear(
                    model.classifier.in_features, opt.n_finetune_classes)
            elif opt.model == 'senet':
                # Mirror the GPU path: the senet head is `last_linear`, and
                # there is no fine-tuning helper for this family.
                model.last_linear = nn.Linear(
                    model.last_linear.in_features, opt.n_finetune_classes)
                return model, model.parameters()
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()
コード例 #31
0
ファイル: model.py プロジェクト: jarrelscy/3D-ResNets-PyTorch
def generate_model(opt):
    """Build the network selected by ``opt`` and return it with its parameters.

    The constructor is looked up by name on the family's module (for example
    ``resnet.resnet50`` for ``opt.model == 'resnet'`` and
    ``opt.model_depth == 50``). Unless ``opt.no_cuda`` is set, the model is
    moved to the GPU and wrapped in ``nn.DataParallel``. When
    ``opt.pretrain_path`` is set, the checkpoint's ``'state_dict'`` is loaded
    and the final linear layer is replaced by a fresh one with
    ``opt.n_finetune_classes`` outputs.

    Args:
        opt: Options object. Always read: ``model``, ``model_depth``,
            ``n_classes``, ``sample_size``, ``sample_duration``, ``no_cuda``,
            ``pretrain_path``. Read depending on the family:
            ``resnet_shortcut``, ``wide_resnet_k``, ``resnext_cardinality``.
            Read only when ``pretrain_path`` is set: ``arch``,
            ``n_finetune_classes``, ``ft_begin_index``.

    Returns:
        A ``(model, parameters)`` tuple. ``parameters`` is the result of the
        family's ``get_fine_tuning_parameters(model, opt.ft_begin_index)``
        when a pretrained checkpoint was loaded, and ``model.parameters()``
        otherwise.

    Raises:
        AssertionError: If ``opt.model`` / ``opt.model_depth`` is not a
            supported combination, or if ``opt.arch`` differs from the
            checkpoint's ``'arch'`` entry.
    """
    assert opt.model in [
        'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
    ]

    if opt.model == 'resnet':
        assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

        from models.resnet import get_fine_tuning_parameters

        build = getattr(resnet, 'resnet%d' % opt.model_depth)
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'wideresnet':
        assert opt.model_depth in [50]

        from models.wide_resnet import get_fine_tuning_parameters

        build = getattr(wide_resnet, 'resnet%d' % opt.model_depth)
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            k=opt.wide_resnet_k,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'resnext':
        assert opt.model_depth in [50, 101, 152]

        from models.resnext import get_fine_tuning_parameters

        build = getattr(resnext, 'resnet%d' % opt.model_depth)
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            cardinality=opt.resnext_cardinality,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'preresnet':
        assert opt.model_depth in [18, 34, 50, 101, 152, 200]

        from models.pre_act_resnet import get_fine_tuning_parameters

        build = getattr(pre_act_resnet, 'resnet%d' % opt.model_depth)
        model = build(
            num_classes=opt.n_classes,
            shortcut_type=opt.resnet_shortcut,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)
    elif opt.model == 'densenet':
        assert opt.model_depth in [121, 169, 201, 264]

        from models.densenet import get_fine_tuning_parameters

        build = getattr(densenet, 'densenet%d' % opt.model_depth)
        model = build(
            num_classes=opt.n_classes,
            sample_size=opt.sample_size,
            sample_duration=opt.sample_duration)

    if not opt.no_cuda:
        model = model.cuda()
        model = nn.DataParallel(model, device_ids=None)

        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            pretrain = torch.load(opt.pretrain_path)
            assert opt.arch == pretrain['arch']

            # Loaded into the DataParallel wrapper, so the checkpoint keys
            # must carry the wrapper's 'module.' prefix.
            model.load_state_dict(pretrain['state_dict'])

            if opt.model == 'densenet':
                model.module.classifier = nn.Linear(
                    model.module.classifier.in_features, opt.n_finetune_classes)
                model.module.classifier = model.module.classifier.cuda()
            else:
                model.module.fc = nn.Linear(model.module.fc.in_features,
                                            opt.n_finetune_classes)
                model.module.fc = model.module.fc.cuda()

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters
    else:
        if opt.pretrain_path:
            print('loading pretrained model {}'.format(opt.pretrain_path))
            # map_location lets a checkpoint that was saved from GPU tensors
            # be read on a host without CUDA.
            pretrain = torch.load(opt.pretrain_path, map_location='cpu')
            assert opt.arch == pretrain['arch']

            # The model is not wrapped in DataParallel here, so strip the
            # 'module.' prefix that a DataParallel-saved checkpoint carries;
            # keys without the prefix pass through unchanged.
            prefix = 'module.'
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in pretrain['state_dict'].items()
            }
            model.load_state_dict(state_dict)

            if opt.model == 'densenet':
                model.classifier = nn.Linear(
                    model.classifier.in_features, opt.n_finetune_classes)
            else:
                model.fc = nn.Linear(model.fc.in_features,
                                     opt.n_finetune_classes)

            parameters = get_fine_tuning_parameters(model, opt.ft_begin_index)
            return model, parameters

    return model, model.parameters()