Example #1
0
    def __init__(self, num_classes, M=32, net='inception_mixed_6e', pretrained=False):
        """Assemble the WSDAN model around the requested feature extractor.

        Args:
            num_classes: Output size of the final linear classifier.
            M: Number of attention maps produced by the 1x1 attention conv.
            net: Backbone identifier. Either 'inception_mixed_6e',
                'inception_mixed_7c', or a name containing 'vgg' / 'resnet'
                that is looked up as an attribute of the matching module.
            pretrained: Forwarded unchanged to the backbone constructor.

        Raises:
            ValueError: If ``net`` does not select a supported backbone.
        """
        super(WSDAN, self).__init__()
        self.num_classes = num_classes
        self.M = M
        self.net = net

        # Feature extractor and the channel count of its output.
        if 'inception' in net:
            if net not in ('inception_mixed_6e', 'inception_mixed_7c'):
                raise ValueError('Unsupported net: %s' % net)
            inception_base = inception_v3(pretrained=pretrained)
            if net == 'inception_mixed_6e':
                self.features = inception_base.get_features_mixed_6e()
                self.num_features = 768
            else:
                self.features = inception_base.get_features_mixed_7c()
                self.num_features = 2048
        elif 'vgg' in net:
            vgg_factory = getattr(vgg, net)
            self.features = vgg_factory(pretrained=pretrained).get_features()
            self.num_features = 512
        elif 'resnet' in net:
            resnet_factory = getattr(resnet, net)
            self.features = resnet_factory(pretrained=pretrained).get_features()
            # Channel count scales with the expansion factor of the last block.
            last_block = self.features[-1][-1]
            self.num_features = 512 * last_block.expansion
        else:
            raise ValueError('Unsupported net: %s' % net)

        # 1x1 convolution that turns backbone features into M attention maps.
        self.attentions = BasicConv2d(self.num_features, self.M, kernel_size=1)

        # Bilinear attention pooling.
        self.bap = BAP(pool='GAP')

        # Bias-free linear classifier over the M * num_features pooled vector.
        fc_inputs = self.M * self.num_features
        self.fc = nn.Linear(fc_inputs, self.num_classes, bias=False)

        logging.info('WSDAN: using {} as feature extractor, num_classes: {}, num_attentions: {}'.format(net, self.num_classes, self.M))
Example #2
0
    def testNoBatchNormScaleByDefault(self):
        """With the default arg scope no BatchNorm gamma variable exists."""
        side = 299
        images = tf.placeholder(tf.float32, (1, side, side, 3))
        default_scope = inception.inception_v3_arg_scope()
        with slim.arg_scope(default_scope):
            inception.inception_v3(images, 1000, is_training=False)

        gammas = tf.global_variables('.*/BatchNorm/gamma:0$')
        self.assertEqual(gammas, [])
Example #3
0
    def testBatchNormScale(self):
        """With batch_norm_scale=True each moving_mean has a sibling gamma."""
        side = 299
        images = tf.placeholder(tf.float32, (1, side, side, 3))
        scaled_scope = inception.inception_v3_arg_scope(batch_norm_scale=True)
        with slim.arg_scope(scaled_scope):
            inception.inception_v3(images, 1000, is_training=False)

        gamma_names = {
            var.op.name
            for var in tf.global_variables('.*/BatchNorm/gamma:0$')
        }
        self.assertGreater(len(gamma_names), 0)

        suffix_len = len('moving_mean')
        for var in tf.global_variables('.*/BatchNorm/moving_mean:0$'):
            # Swap the trailing 'moving_mean' for 'gamma' to get the
            # name the matching scale variable must have.
            expected_gamma = var.op.name[:-suffix_len] + 'gamma'
            self.assertIn(expected_gamma, gamma_names)
Example #4
0
    def testRaiseValueErrorWithInvalidDepthMultiplier(self):
        """Negative and zero depth multipliers are rejected with ValueError."""
        class_count = 1000
        images = tf.random_uniform((5, 299, 299, 3))

        # Checked in the same order as before: negative first, then zero.
        for bad_multiplier in (-0.1, 0.0):
            with self.assertRaises(ValueError):
                _ = inception.inception_v3(images,
                                           class_count,
                                           depth_multiplier=bad_multiplier)
Example #5
0
def get_classifier(arch, num_classes=1000, pretrained=False):
    """Instantiate a classification network by architecture name.

    Args:
        arch: One of 'alexnet', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
            'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
            (built from the ``models`` module) or 'inception_v3' (built by
            the separately imported ``inception_v3`` factory).
        num_classes: Forwarded to the factory as ``num_classes``.
        pretrained: Forwarded to the factory as its first positional
            argument.

    Returns:
        The model object returned by the selected factory.

    Raises:
        NotImplementedError: If ``arch`` is not one of the names above. The
            message names the rejected value.
    """
    models_archs = (
        'alexnet',
        'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
        'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    )

    if arch in models_archs:
        # Resolve the factory lazily so only the requested attribute of
        # ``models`` is ever touched.
        factory = getattr(models, arch)
        return factory(pretrained, num_classes=num_classes)

    if arch == 'inception_v3':
        return inception_v3(pretrained, num_classes=num_classes)

    raise NotImplementedError(
        'Unsupported architecture: {!r}'.format(arch))
Example #6
0
def get_model_dic(device):
    """Load every ensemble member and return them keyed by short name.

    Five networks are constructed with ``num_classes=110`` and restored
    through ``load_model``; three further networks are unpickled whole from
    ``./models_old`` and moved to ``device``.

    Args:
        device: Passed to ``load_model`` and to ``.to(...)`` for the
            unpickled models.

    Returns:
        dict mapping model name to the loaded model object.
    """

    def _with_weights(network, checkpoint):
        # load_model is called for its effect on ``network``; whatever it
        # returns is ignored, exactly as in the straight-line version.
        load_model(network, checkpoint, device)
        return network

    def _legacy(source_file, pickle_file):
        # The source file is loaded under the module name 'MainModel' right
        # before unpickling.
        # NOTE(review): presumably the pickled checkpoint refers to classes
        # in a module called 'MainModel' -- confirm before reordering.
        imp.load_source('MainModel', source_file)
        return torch.load(pickle_file).to(device)

    # NOTE: densenet121, googlenet and vgg16_bn members are currently
    # disabled and therefore not loaded here.
    densenet_161 = _with_weights(
        densenet161(num_classes=110),
        "./pre_weights/ep_30_densenet161_val_acc_0.6990.pth")
    resnet_50 = _with_weights(
        resnet50(num_classes=110),
        "./pre_weights/ep_41_resnet50_val_acc_0.6900.pth")
    incept_v3 = _with_weights(
        inception_v3(num_classes=110),
        "./pre_weights/ep_36_inception_v3_val_acc_0.6668.pth")
    incept_resnet_v2_adv = _with_weights(
        InceptionResNetV2(num_classes=110),
        "./pre_weights/ep_22_InceptionResNetV2_val_acc_0.8214.pth")
    incept_v4_adv = _with_weights(
        InceptionV4(num_classes=110),
        "./pre_weights/ep_37_InceptionV4_val_acc_0.7119.pth")

    # Whole-model pickles; loaded in the order resnet, vgg, inception.
    resnet_model = _legacy("./models_old/tf_to_pytorch_resnet_v1_50.py",
                           './models_old/tf_to_pytorch_resnet_v1_50.pth')
    vgg_model = _legacy("./models_old/tf_to_pytorch_vgg16.py",
                        './models_old/tf_to_pytorch_vgg16.pth')
    inception_model = _legacy("./models_old/tf_to_pytorch_inception_v1.py",
                              './models_old/tf_to_pytorch_inception_v1.pth')

    return {
        "densenet161": densenet_161,
        "resnet_50": resnet_50,
        "incept_v3": incept_v3,
        "incept_resnet_v2_adv": incept_resnet_v2_adv,
        "incept_v4_adv": incept_v4_adv,
        "old_incept": inception_model,
        "old_res": resnet_model,
        "old_vgg": vgg_model,
    }
Example #7
0
    def testTrainEvalWithReuse(self):
        """Build a training graph, then an eval graph reusing its variables.

        The second network is built with ``reuse=True``, ``is_training=False``
        and a different batch size; the argmax over its logits must yield one
        prediction per eval example.
        """
        train_batch_size = 5
        eval_batch_size = 2
        height, width = 150, 150
        num_classes = 1000

        train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
        inception.inception_v3(train_inputs, num_classes)
        eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
        logits, _ = inception.inception_v3(eval_inputs,
                                           num_classes,
                                           is_training=False,
                                           reuse=True)
        predictions = tf.argmax(logits, 1)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            output = sess.run(predictions)
            # assertEqual instead of the deprecated assertEquals alias,
            # which no longer exists on Python 3.12+.
            self.assertEqual(output.shape, (eval_batch_size, ))
Example #8
0
    def testBuildPreLogitsNetwork(self):
        """With num_classes=None the returned tensor is the pooled pre-logits.

        The output op must live under 'InceptionV3/Logits/AvgPool', have
        shape [batch, 1, 1, 2048], and neither 'Logits' nor 'Predictions'
        may appear among the end points.
        """
        batch_size = 5
        height, width = 299, 299
        num_classes = None

        inputs = tf.random_uniform((batch_size, height, width, 3))
        net, end_points = inception.inception_v3(inputs, num_classes)
        self.assertTrue(net.op.name.startswith('InceptionV3/Logits/AvgPool'))
        self.assertListEqual(net.get_shape().as_list(),
                             [batch_size, 1, 1, 2048])
        # assertNotIn reports the offending key and container on failure,
        # unlike assertFalse(x in y) which only says "True is not false".
        self.assertNotIn('Logits', end_points)
        self.assertNotIn('Predictions', end_points)
Example #9
0
    def testLogitsNotSqueezed(self):
        """With spatial_squeeze=False the logits keep shape [1, 1, 1, classes]."""
        class_count = 25
        batch = tf.random_uniform([1, 299, 299, 3])
        raw_logits, _ = inception.inception_v3(batch,
                                               num_classes=class_count,
                                               spatial_squeeze=False)

        with self.test_session() as session:
            tf.global_variables_initializer().run()
            evaluated = session.run(raw_logits)
            actual_shape = list(evaluated.shape)
            self.assertListEqual(actual_shape, [1, 1, 1, class_count])
Example #10
0
    def testHalfSizeImages(self):
        """150x150 inputs still classify; Mixed_7c is then a 3x3x2048 map."""
        n, side, class_count = 5, 150, 1000

        images = tf.random_uniform((n, side, side, 3))
        logits, end_points = inception.inception_v3(images, class_count)

        op_name = logits.op.name
        self.assertTrue(op_name.startswith('InceptionV3/Logits'))
        logits_shape = logits.get_shape().as_list()
        self.assertListEqual(logits_shape, [n, class_count])

        mixed_7c_shape = end_points['Mixed_7c'].get_shape().as_list()
        self.assertListEqual(mixed_7c_shape, [n, 3, 3, 2048])
Example #11
0
def get_network(args):
    """Build the backbone named by ``args.net`` and attach a new output layer.

    The matching factory is imported lazily from the ``models`` package.
    After construction the last linear layer is replaced by a fresh
    ``nn.Linear`` with ``args.num_classes`` outputs: ``classifier[1]`` for
    'mobilenet' (with four times the original input width), ``fc`` for
    everything else.

    Args:
        args: Namespace providing ``net`` and ``num_classes``; most branches
            also read ``export_onnx``, and 'inception' receives ``args``
            itself as first argument.

    Returns:
        The configured model, or ``None`` (after printing a message) when
        ``args.net`` is not a supported name.
    """
    requested = args.net

    if requested == 'vgg16':
        from models.vgg import vgg16
        model = vgg16(args.num_classes, export_onnx=args.export_onnx)
    elif requested == 'alexnet':
        from models.alexnet import alexnet
        model = alexnet(num_classes=args.num_classes,
                        export_onnx=args.export_onnx)
    elif requested == 'mobilenet':
        from models.mobilenet import mobilenet_v2
        model = mobilenet_v2(pretrained=True, export_onnx=args.export_onnx)
    elif requested == 'vgg19':
        from models.vgg import vgg19
        model = vgg19(args.num_classes, export_onnx=args.export_onnx)
    elif requested == 'googlenet':
        from models.googlenet import googlenet
        model = googlenet(pretrained=True)
    elif requested == 'inception':
        from models.inception import inception_v3
        model = inception_v3(args,
                             pretrained=True,
                             export_onnx=args.export_onnx)
    elif requested == 'resnet18':
        from models.resnet import resnet18
        model = resnet18(pretrained=True, export_onnx=args.export_onnx)
    elif requested == 'resnet34':
        from models.resnet import resnet34
        model = resnet34(pretrained=True, export_onnx=args.export_onnx)
    elif requested == 'resnet101':
        from models.resnet import resnet101
        model = resnet101(pretrained=True, export_onnx=args.export_onnx)
    elif requested == 'resnet50':
        from models.resnet import resnet50
        model = resnet50(pretrained=True, export_onnx=args.export_onnx)
    elif requested == 'resnet152':
        from models.resnet import resnet152
        model = resnet152(pretrained=True, export_onnx=args.export_onnx)
    else:
        print("The %s is not supported..." % (args.net))
        return

    # Swap in a new output layer sized for the target number of classes.
    if requested == 'mobilenet':
        in_dim = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(in_dim * 4, args.num_classes)
    else:
        in_dim = model.fc.in_features
        model.fc = nn.Linear(in_dim, args.num_classes)

    return model
Example #12
0
    def testBuildClassificationNetwork(self):
        """The default network yields squeezed logits and a Predictions end point.

        Both the logits and ``end_points['Predictions']`` must have shape
        [batch_size, num_classes], and the logits op must live under
        'InceptionV3/Logits/SpatialSqueeze'.
        """
        batch_size = 5
        height, width = 299, 299
        num_classes = 1000

        inputs = tf.random_uniform((batch_size, height, width, 3))
        logits, end_points = inception.inception_v3(inputs, num_classes)
        self.assertTrue(
            logits.op.name.startswith('InceptionV3/Logits/SpatialSqueeze'))
        self.assertListEqual(logits.get_shape().as_list(),
                             [batch_size, num_classes])
        # assertIn names the missing key on failure instead of the opaque
        # "False is not true" produced by assertTrue(x in y).
        self.assertIn('Predictions', end_points)
        self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
                             [batch_size, num_classes])
Example #13
0
def main():
    """Restore backbone and classifier from checkpoints and start training.

    Both networks are created on ``device``, loaded from the paths in
    ``__OLD_BACKBONE_PATH`` / ``__OLD_CLASSIFIER_PATH``, wrapped in
    ``nn.DataParallel`` and handed to ``train_model`` together with the
    data loaders, the loss functions and a single Adam optimizer covering
    the parameters of both networks.
    """
    print('creating backbone')
    feature_net = inception_v3().to(device)

    print('creating classifier')
    head = ClassifierA(feature_dim=__CLASSIFIER_INPUT,
                       embed_dim=__EMBED_DIM,
                       dropout_ratio=__DROPOUT_RATIO).to(device)

    # NOTE: a partial restore (copying only checkpoint keys that exist in
    # the backbone) is currently disabled; the full state dict is loaded.
    print('loading backbone from ' + __OLD_BACKBONE_PATH)
    backbone_state = torch.load(__OLD_BACKBONE_PATH)
    feature_net.load_state_dict(backbone_state)
    print('loading classifier from ' + __OLD_CLASSIFIER_PATH)
    classifier_state = torch.load(__OLD_CLASSIFIER_PATH)
    head.load_state_dict(classifier_state)

    feature_net = nn.DataParallel(feature_net, device_ids=device_ids)
    head = nn.DataParallel(head, device_ids=device_ids)

    # One shuffled loader per dataset split.
    loaders = {}
    for split in ['trainval', 'test']:
        loaders[split] = DataLoader(VolleyballDataset(split),
                                    batch_size=__BATCH_SIZE,
                                    shuffle=True,
                                    num_workers=2,
                                    collate_fn=collate_fn)

    # Only the 'action' loss is class-weighted.
    losses = {
        'action': nn.CrossEntropyLoss(weight=__ACTION_WEIGHT),
        'activity': nn.CrossEntropyLoss(),
    }

    trainable = [*feature_net.parameters(), *head.parameters()]
    optimizer = optim.Adam(trainable,
                           lr=__LEARNING_RATE,
                           weight_decay=__WEIGHT_DECAY)

    train_model(feature_net, head, loaders, losses, optimizer)
Example #14
0
    def testEvaluation(self):
        """In eval mode the argmax over logits yields one label per example."""
        batch_size = 2
        height, width = 299, 299
        num_classes = 1000

        eval_inputs = tf.random_uniform((batch_size, height, width, 3))
        logits, _ = inception.inception_v3(eval_inputs,
                                           num_classes,
                                           is_training=False)
        predictions = tf.argmax(logits, 1)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            output = sess.run(predictions)
            # assertEqual instead of the deprecated assertEquals alias,
            # which no longer exists on Python 3.12+.
            self.assertEqual(output.shape, (batch_size, ))
Example #15
0
    def testUnknowBatchSize(self):
        """A placeholder with batch dimension None builds and runs.

        The static logits shape must be [None, num_classes]; feeding a
        concrete batch must produce an output of shape
        (batch_size, num_classes).
        """
        batch_size = 1
        height, width = 299, 299
        num_classes = 1000

        inputs = tf.placeholder(tf.float32, (None, height, width, 3))
        logits, _ = inception.inception_v3(inputs, num_classes)
        self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
        self.assertListEqual(logits.get_shape().as_list(), [None, num_classes])
        images = tf.random_uniform((batch_size, height, width, 3))

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            # images.eval() materialises the random tensor so that it can be
            # fed through the placeholder.
            output = sess.run(logits, {inputs: images.eval()})
            # assertEqual instead of the deprecated assertEquals alias,
            # which no longer exists on Python 3.12+.
            self.assertEqual(output.shape, (batch_size, num_classes))
Example #16
0
    def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self):
        """depth_multiplier=2.0 doubles the depth of every Mixed*/Conv* end point."""
        class_count = 1000
        images = tf.random_uniform((5, 299, 299, 3))

        _, baseline = inception.inception_v3(images, class_count)
        tracked = [
            name for name in baseline
            if name.startswith(('Mixed', 'Conv'))
        ]

        # Second network in its own scope so its variables do not clash
        # with the baseline network built above.
        _, widened = inception.inception_v3(images,
                                            class_count,
                                            scope='depth_multiplied_net',
                                            depth_multiplier=2.0)

        for name in tracked:
            base_depth = baseline[name].get_shape().as_list()[3]
            wide_depth = widened[name].get_shape().as_list()[3]
            self.assertEqual(2.0 * base_depth, wide_depth)
Example #17
0
def get_network(args, cfg):
    """Return the network selected by ``args.net``, moved to CUDA.

    Most factories receive ``pretrained=args.pretrain`` and
    ``num_classes=cfg.PARA.train.num_classes``; 'lenet5', 'inceptionv3' and
    'squeezenet' are built with no arguments. 'resnet18' and 'resnet50'
    are placed with ``.cuda(args.gpuid)``, every other model with a plain
    ``.cuda()``.

    An unsupported name prints a message and terminates the process via
    ``sys.exit()``.
    """
    # pdb.set_trace()
    requested = args.net

    def _sized(factory):
        # Shared call for factories taking the pretrain flag and class count.
        return factory(pretrained=args.pretrain,
                       num_classes=cfg.PARA.train.num_classes)

    if requested == 'lenet5':
        return LeNet5().cuda()
    if requested == 'alexnet':
        return _sized(alexnet).cuda()
    if requested == 'vgg16':
        return _sized(vgg16).cuda()
    if requested == 'vgg13':
        return _sized(vgg13).cuda()
    if requested == 'vgg11':
        return _sized(vgg11).cuda()
    if requested == 'vgg19':
        return _sized(vgg19).cuda()
    if requested == 'vgg16_bn':
        return _sized(vgg16_bn).cuda()
    if requested == 'vgg13_bn':
        return _sized(vgg13_bn).cuda()
    if requested == 'vgg11_bn':
        return _sized(vgg11_bn).cuda()
    if requested == 'vgg19_bn':
        return _sized(vgg19_bn).cuda()
    if requested == 'inceptionv3':
        return inception_v3().cuda()
    # 'inceptionv4' and 'inceptionresnetv2' are currently disabled.
    if requested == 'resnet18':
        return _sized(resnet18).cuda(args.gpuid)
    if requested == 'resnet34':
        return _sized(resnet34).cuda()
    if requested == 'resnet50':
        return _sized(resnet50).cuda(args.gpuid)
    if requested == 'resnet101':
        return _sized(resnet101).cuda()
    if requested == 'resnet152':
        return _sized(resnet152).cuda()
    if requested == 'squeezenet':
        return squeezenet1_0().cuda()

    print('the network name you have entered is not supported yet')
    sys.exit()
Example #18
0
 def testUnknownImageShape(self):
     """A placeholder with unknown height/width still yields 8x8 Mixed_7c for 299x299 input."""
     tf.reset_default_graph()
     n, side, class_count = 2, 299, 1000
     fed_images = np.random.uniform(0, 1, (n, side, side, 3))
     with self.test_session() as session:
         # Only batch size and channel count are static here.
         images = tf.placeholder(tf.float32, shape=(n, None, None, 3))
         logits, end_points = inception.inception_v3(images, class_count)
         static_logits_shape = logits.get_shape().as_list()
         self.assertListEqual(static_logits_shape, [n, class_count])
         mixed_7c = end_points['Mixed_7c']
         tf.global_variables_initializer().run()
         mixed_7c_value = session.run(mixed_7c,
                                      feed_dict={images: fed_images})
         self.assertListEqual(list(mixed_7c_value.shape),
                              [n, 8, 8, 2048])
Example #19
0
    def testBuildEndPoints(self):
        """The key end points exist and have the expected static shapes.

        Checks 'Logits', 'AuxLogits', 'Mixed_7c' and 'PreLogits', in that
        order, for presence in ``end_points`` and for their shape with a
        299x299 input.
        """
        batch_size = 5
        height, width = 299, 299
        num_classes = 1000

        inputs = tf.random_uniform((batch_size, height, width, 3))
        _, end_points = inception.inception_v3(inputs, num_classes)

        expected_shapes = [
            ('Logits', [batch_size, num_classes]),
            ('AuxLogits', [batch_size, num_classes]),
            ('Mixed_7c', [batch_size, 8, 8, 2048]),
            ('PreLogits', [batch_size, 1, 1, 2048]),
        ]
        for name, shape in expected_shapes:
            # assertIn names the missing key on failure, unlike
            # assertTrue(name in end_points).
            self.assertIn(name, end_points)
            self.assertListEqual(end_points[name].get_shape().as_list(),
                                 shape)
Example #20
0
def get_model(args):

    assert args.model in [
        'derpnet', 'alexnet', 'resnet', 'vgg', 'vgg_attn', 'inception'
    ]

    if args.model == 'alexnet':
        model = alexnet.alexnet(pretrained=args.pretrained,
                                n_channels=args.n_channels,
                                num_classes=args.n_classes)
    elif args.model == 'inception':
        model = inception.inception_v3(pretrained=args.pretrained,
                                       aux_logits=False,
                                       progress=True,
                                       num_classes=args.n_classes)
    elif args.model == 'vgg':
        assert args.model_depth in [11, 13, 16, 19]

        if args.model_depth == 11:
            model = vgg.vgg11_bn(pretrained=args.pretrained,
                                 progress=True,
                                 num_classes=args.n_classes)
        if args.model_depth == 13:
            model = vgg.vgg13_bn(pretrained=args.pretrained,
                                 progress=True,
                                 num_classes=args.n_classes)
        if args.model_depth == 16:
            model = vgg.vgg16_bn(pretrained=args.pretrained,
                                 progress=True,
                                 num_classes=args.n_classes)
        if args.model_depth == 19:
            model = vgg.vgg19(pretrained=args.pretrained,
                              progress=True,
                              num_classes=args.n_classes)

    elif args.model == 'vgg_attn':
        assert args.model_depth in [11, 13, 16, 19]

        if args.model_depth == 11:
            model = vgg_attn.vgg11_bn(pretrained=args.pretrained,
                                      progress=True,
                                      num_classes=args.n_classes)
        if args.model_depth == 13:
            model = vgg_attn.vgg11_bn(pretrained=args.pretrained,
                                      progress=True,
                                      num_classes=args.n_classes)
        if args.model_depth == 16:
            model = vgg_attn.vgg11_bn(pretrained=args.pretrained,
                                      progress=True,
                                      num_classes=args.n_classes)
        if args.model_depth == 19:
            model = vgg_attn.vgg11_bn(pretrained=args.pretrained,
                                      progress=True,
                                      num_classes=args.n_classes)

    elif args.model == 'derpnet':
        model = derp_net.Net(n_channels=args.n_channels,
                             num_classes=args.n_classes)

    elif args.model == 'resnet':
        assert args.model_depth in [10, 18, 34, 50, 101, 152, 200]

        if args.model_depth == 10:
            model = resnet.resnet10(pretrained=args.pretrained,
                                    num_classes=args.n_classes)
        elif args.model_depth == 18:
            model = resnet.resnet18(pretrained=args.pretrained,
                                    num_classes=args.n_classes)
        elif args.model_depth == 34:
            model = resnet.resnet34(pretrained=args.pretrained,
                                    num_classes=args.n_classes)
        elif args.model_depth == 50:
            model = resnet.resnet50(pretrained=args.pretrained,
                                    num_classes=args.n_classes)
        elif args.model_depth == 101:
            model = resnet.resnet101(pretrained=args.pretrained,
                                     num_classes=args.n_classes)
        elif args.model_depth == 152:
            model = resnet.resnet152(pretrained=args.pretrained,
                                     num_classes=args.n_classes)
        elif args.model_depth == 200:
            model = resnet.resnet200(pretrained=args.pretrained,
                                     num_classes=args.n_classes)

    if args.pretrained and args.pretrain_path and not args.model == 'alexnet' and not args.model == 'vgg' and not args.model == 'resnet':

        print('loading pretrained model {}'.format(args.pretrain_path))
        pretrain = torch.load(args.pretrain_path)
        assert args.arch == pretrain['arch']

        # here all the magic happens: need to pick the parameters which will be adjusted during training
        # the rest of the params will be frozen
        pretrain_dict = {
            key[7:]: value
            for key, value in pretrain['state_dict'].items()
            if key[7:9] != 'fc'
        }
        from collections import OrderedDict
        pretrain_dict = OrderedDict(pretrain_dict)

        # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance
        import types
        model.load_state_dict = types.MethodType(load_my_state_dict, model)

        old_dict = copy.deepcopy(
            model.state_dict())  # normal copy() just gives a reference
        model.load_state_dict(pretrain_dict)
        new_dict = model.state_dict()

        num_features = model.fc.in_features
        if args.model == 'densenet':
            model.classifier = nn.Linear(num_features, args.n_classes)
        else:
            #model.fc = nn.Sequential(nn.Linear(num_features, 1028), nn.ReLU(), nn.Dropout(0.5), nn.Linear(1028, args.n_finetune_classes))
            model.fc = nn.Linear(num_features, args.n_classes)

        # parameters = get_fine_tuning_parameters(model, args.ft_begin_index)
        parameters = model.parameters()  # fine-tunining EVERYTHIIIIIANG
        # parameters = model.fc.parameters()  # fine-tunining ONLY FC layer
        return model, parameters

    return model, model.parameters()
Example #21
0
def train_model(modname='alexnet', pm_ch='both', bs=16):
    """Train one 3-class network on pixel maps and save its weights.

    Args:
        modname (string): Architecture to train; one of 'alexnet',
            'densenet', 'inception', 'resnet', 'squeezenet', 'vgg'. Any
            other value prints a message and returns without training.
        pm_ch (string): Pixelmap channel selection. 'both' gives a
            two-channel input, any other value a single channel.
        bs (int): Batch size of the data loader; the loss is also printed
            every ``bs`` steps.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Hyper parameters.
    max_epochs = 10
    learning_rate = 0.001

    # Number of input channels.
    nch = 1 if pm_ch != 'both' else 2

    ds = PixelMapDataset('training_file_list.txt', pm_ch)
    dl = torch.utils.data.DataLoader(dataset=ds, batch_size=bs, shuffle=True)

    # Pick the constructor; all of them take the same two keyword arguments.
    if modname == 'alexnet':
        factory = alexnet
    elif modname == 'densenet':
        factory = DenseNet
    elif modname == 'inception':
        factory = inception_v3
    elif modname == 'resnet':
        factory = resnet18
    elif modname == 'squeezenet':
        factory = squeezenet1_1
    elif modname == 'vgg':
        factory = vgg19_bn
    else:
        print('Model {} not defined.'.format(modname))
        return
    model = factory(num_classes=3, in_ch=nch).to(device)

    # Loss and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Inception gets a larger zero padding than the other networks.
    if modname == 'inception':
        pad = nn.ZeroPad2d((0, 192, 102, 101))
    else:
        pad = nn.ZeroPad2d((0, 117, 64, 64))

    steps_per_epoch = len(dl)
    for epoch in range(max_epochs):
        # Only the first view of each sample is used for training.
        for step, (view1, _view2, local_labels) in enumerate(dl, start=1):
            batch = pad(view1.float().to(device))
            targets = local_labels.to(device)

            # Forward pass.
            outputs = model(batch)
            loss = criterion(outputs, targets)

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % bs == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1, max_epochs, step, steps_per_epoch,
                    loss.item()))

    # Save the model checkpoint.
    save_path = '../../../data/two_views/saved_models/{}/{}'.format(
        modname, pm_ch)
    os.makedirs(save_path, exist_ok=True)
    checkpoint_file = os.path.join(save_path, 'model.ckpt')
    torch.save(model.state_dict(), checkpoint_file)
Example #22
0
def get_model(class_num):
    """Build the network selected by the module-level configuration.

    ``MODEL_TYPE`` picks the family and ``MODEL_DEPTH_OR_VERSION`` the
    variant; ``FINETUNE`` is forwarded as ``pretrained``. For alexnet, vgg,
    resnet, inception and densenet the final linear layer is then replaced
    by one with ``class_num`` outputs; squeezenet models are returned as
    constructed.

    An unsupported family or variant prints an error and exits the process
    with status 1.

    Args:
        class_num: Output size of the replacement final layer.

    Returns:
        The constructed model.
    """

    def _abort(message):
        # Report a configuration error and terminate.
        print(message)
        sys.exit(1)

    if MODEL_TYPE == 'alexnet':
        model = alexnet.alexnet(pretrained=FINETUNE)
    elif MODEL_TYPE == 'vgg':
        if MODEL_DEPTH_OR_VERSION in (11, 13, 16, 19):
            factory = getattr(vgg, 'vgg%d' % MODEL_DEPTH_OR_VERSION)
            model = factory(pretrained=FINETUNE)
        else:
            _abort('Error : VGG should have depth of either [11, 13, 16, 19]')
    elif MODEL_TYPE == 'squeezenet':
        if MODEL_DEPTH_OR_VERSION in (0, 'v0'):
            model = squeezenet.squeezenet1_0(pretrained=FINETUNE)
        elif MODEL_DEPTH_OR_VERSION in (1, 'v1'):
            model = squeezenet.squeezenet1_1(pretrained=FINETUNE)
        else:
            _abort('Error : Squeezenet should have version of either [0, 1]')
    elif MODEL_TYPE == 'resnet':
        if MODEL_DEPTH_OR_VERSION in (18, 34, 50, 101, 152):
            factory = getattr(resnet, 'resnet%d' % MODEL_DEPTH_OR_VERSION)
            model = factory(pretrained=FINETUNE)
        else:
            _abort(
                'Error : Resnet should have depth of either [18, 34, 50, 101, 152]'
            )
    elif MODEL_TYPE == 'densenet':
        if MODEL_DEPTH_OR_VERSION in (121, 169, 161, 201):
            factory = getattr(densenet, 'densenet%d' % MODEL_DEPTH_OR_VERSION)
            model = factory(pretrained=FINETUNE)
        else:
            _abort(
                'Error : Densenet should have depth of either [121, 169, 161, 201]'
            )
    elif MODEL_TYPE == 'inception':
        if MODEL_DEPTH_OR_VERSION in (3, 'v3'):
            model = inception.inception_v3(pretrained=FINETUNE)
        else:
            _abort('Error : Inception should have version of either [3, ]')
    else:
        _abort(
            'Error : Network should be either [alexnet / squeezenet / vgg / resnet / densenet / inception]'
        )

    # Replace the final linear layer with one sized for class_num outputs.
    if MODEL_TYPE in ('alexnet', 'vgg'):
        in_dim = model.classifier[6].in_features
        # Keep every classifier layer except the last, then append the new
        # output layer.
        kept_layers = list(model.classifier.children())[:-1]
        model.classifier = nn.Sequential(*kept_layers,
                                         nn.Linear(in_dim, class_num))
    elif MODEL_TYPE in ('resnet', 'inception'):
        in_dim = model.fc.in_features
        model.fc = nn.Linear(in_dim, class_num)
    elif MODEL_TYPE == 'densenet':
        in_dim = model.classifier.in_features
        model.classifier = nn.Linear(in_dim, class_num)

    return model
Example #23
0
def training():
    """Pre-train, prune and fine-tune the Keras model selected by ``args``.

    Reads the module-level ``args`` namespace (``data``, ``model``, ``depth``,
    ``wide_factor``, ``growth_rate``, ``compress_rate``) and runs, in order:

    1. Load the dataset and one-hot encode the labels.
    2. Build the requested network and train it on augmented batches.
    3. Save the training history and the trained weights under ``./results/``.
    4. Prune every layer that exposes at least two weight arrays, except the
       first such layer, keeping the masks returned by ``prune_weights``.
    5. Fine-tune batch by batch, multiplying each mask back into its layer's
       first weight array after every training step.
    6. Save the compressed weights.

    Raises:
        ValueError: if ``args.model`` is not 'resnet', 'densenet',
            'inception' or 'vgg'.
    """
    batch_size = 64
    epochs = 200
    fine_tune_epochs = 30
    lr = 0.1

    x_train, y_train, x_test, y_test, nb_classes = load_data(args.data)
    print('x_train shape:', x_train.shape)
    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')

    y_train = keras.utils.to_categorical(y_train, nb_classes)
    y_test = keras.utils.to_categorical(y_test, nb_classes)

    # Augmentation: random horizontal flips plus shifts of 5/32 of the image.
    datagen = ImageDataGenerator(horizontal_flip=True,
                                 width_shift_range=5. / 32,
                                 height_shift_range=5. / 32)
    data_iter = datagen.flow(x_train,
                             y_train,
                             batch_size=batch_size,
                             shuffle=True)
    # Number of full batches per pass over the training set; shared by the
    # pre-training and fine-tuning stages.
    steps_per_epoch = x_train.shape[0] // batch_size

    if args.model == 'resnet':
        model = resnet.resnet(nb_classes,
                              depth=args.depth,
                              wide_factor=args.wide_factor)
        save_name = 'resnet_{}_{}_{}'.format(args.depth, args.wide_factor,
                                             args.data)
    elif args.model == 'densenet':
        model = densenet.densenet(nb_classes,
                                  args.growth_rate,
                                  depth=args.depth)
        save_name = 'densenet_{}_{}_{}'.format(args.depth, args.growth_rate,
                                               args.data)
    elif args.model == 'inception':
        model = inception.inception_v3(nb_classes)
        save_name = 'inception_{}'.format(args.data)
    elif args.model == 'vgg':
        model = vggnet.vgg(nb_classes)
        save_name = 'vgg_{}'.format(args.data)
    else:
        raise ValueError('Does not support {}'.format(args.model))

    model.summary()

    def report_validation():
        # Evaluate on the held-out split and print loss / accuracy.
        score = model.evaluate(x_test, y_test, verbose=0)
        print('val loss: {}'.format(score[0]))
        print('val acc: {}'.format(score[1]))

    if args.model == 'vgg':
        # VGG is trained with Adam and no learning-rate schedule.
        callbacks = None
        model.compile(loss='categorical_crossentropy',
                      optimizer='adam',
                      metrics=['accuracy'])
    else:
        callbacks = [LearningRateScheduler(schedule=schedule)]
        model.compile(loss='categorical_crossentropy',
                      optimizer=keras.optimizers.SGD(lr=lr,
                                                     momentum=0.9,
                                                     nesterov=True),
                      metrics=['accuracy'])

    # pre-train
    history = model.fit_generator(data_iter,
                                  steps_per_epoch=steps_per_epoch,
                                  epochs=epochs,
                                  callbacks=callbacks,
                                  validation_data=(x_test, y_test))
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('./results/', exist_ok=True)
    save_history(history, './results/', save_name)
    model.save_weights('./results/{}_weights.h5'.format(save_name))

    # prune weights, remembering one mask per pruned layer (keyed by the
    # layer's index in model.layers)
    masks = {}
    # the first layer with >= 2 weight arrays is left uncompressed
    first_conv = True
    for layer_id, layer in enumerate(model.layers):
        weight = layer.get_weights()
        if len(weight) < 2:
            continue
        if first_conv:
            first_conv = False
            continue
        w = deepcopy(weight)
        w[0], masks[layer_id] = prune_weights(
            w[0], compress_rate=args.compress_rate)
        layer.set_weights(w)

    # evaluate model after pruning
    report_validation()

    # fine-tune
    for _ in range(fine_tune_epochs):
        for _ in range(steps_per_epoch):
            X, Y = next(data_iter)
            # train on each batch
            model.train_on_batch(X, Y)
            # Re-apply the masks after the update (presumably 0/1 masks that
            # keep pruned entries at zero -- confirm against prune_weights).
            for layer_id, mask in masks.items():
                w = model.layers[layer_id].get_weights()
                w[0] = w[0] * mask
                model.layers[layer_id].set_weights(w)
        report_validation()

    # save compressed weights
    compressed_name = './results/compressed_{}_weights'.format(args.model)
    save_compressed_weights(model, compressed_name)
Example #24
0
def test_model(modname='alexnet', pm_ch='both', bs=16):
    """Evaluate a saved 3-class classifier on the test set and print accuracies.

    The network named by ``modname`` is rebuilt, its weights are restored
    from ``../../../data/two_views/saved_models/<modname>/<pm_ch>/model.ckpt``
    and it is run over the files listed in ``test_file_list.txt``. Only the
    first view of every sample is fed to the network.

    Three figures are printed: overall 3-class accuracy, a 2-class accuracy
    in which labels 0 and 1 are merged, and the accuracy restricted to
    samples whose label is below 2.

    Args:
        modname: one of 'alexnet', 'densenet', 'inception', 'resnet',
            'squeezenet', 'vgg'. Any other value prints a message and returns.
        pm_ch: pixel-map channel selector; 'both' gives a 2-channel input,
            anything else a 1-channel input.
        bs: batch size of the test loader.

    Returns:
        None. Results are printed. The function returns early, after printing
        a message, when the model name is unknown or the checkpoint file is
        missing.
    """
    # hyperparameters
    batch_size = bs

    # device configuration
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # determine number of input channels
    nch = 2 if pm_ch == 'both' else 1

    # restore model
    if modname == 'alexnet':
        model = alexnet(num_classes=3, in_ch=nch).to(device)
    elif modname == 'densenet':
        model = DenseNet(num_classes=3, in_ch=nch).to(device)
    elif modname == 'inception':
        model = inception_v3(num_classes=3, in_ch=nch).to(device)
    elif modname == 'resnet':
        model = resnet18(num_classes=3, in_ch=nch).to(device)
    elif modname == 'squeezenet':
        model = squeezenet1_1(num_classes=3, in_ch=nch).to(device)
    elif modname == 'vgg':
        model = vgg19_bn(in_ch=nch, num_classes=3).to(device)
    else:
        print('Model {} not defined.'.format(modname))
        return

    # retrieve trained model
    load_path = '../../../data/two_views/saved_models/{}/{}'.format(
        modname, pm_ch)
    model_pathname = os.path.join(load_path, 'model.ckpt')
    if not os.path.exists(model_pathname):
        print('Trained model file {} does not exist. Abort.'.format(
            model_pathname))
        return
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only host (and vice versa).
    model.load_state_dict(torch.load(model_pathname, map_location=device))

    # load test dataset
    test_dataset = PixelMapDataset('test_file_list.txt', pm_ch)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    # Inception gets a larger zero padding than the other networks; the
    # padding module is stateless, so build it once outside the loop.
    if modname == 'inception':
        pad = nn.ZeroPad2d((0, 192, 102, 101))
    else:
        pad = nn.ZeroPad2d((0, 117, 64, 64))

    def percent(hits, count):
        # An empty denominator yields NaN instead of a ZeroDivisionError.
        return 100 * hits / count if count else float('nan')

    # eval mode (batchnorm uses moving mean/variance instead of mini-batch
    # mean/variance)
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        correct_cc_or_bkg = 0
        ws_total = 0
        ws_correct = 0
        for view1, _view2, labels in test_loader:
            view1 = pad(view1.float().to(device))
            labels = labels.to(device)
            outputs = model(view1)
            _, predicted = torch.max(outputs.data, 1)

            hits = predicted == labels
            total += labels.size(0)
            correct += hits.sum().item()

            # 2-class view: prediction and label are either both below 2 or
            # both equal to 2. Assumes labels is 1-D like predicted, as the
            # element-wise comparison above already requires.
            low_label = labels < 2
            both_low = (predicted < 2) & low_label
            both_two = (predicted == 2) & (labels == 2)
            correct_cc_or_bkg += (both_low | both_two).sum().item()

            # Accuracy restricted to samples whose label is below 2.
            ws_total += low_label.sum().item()
            ws_correct += (hits & low_label).sum().item()

        print('Model Performance:')
        print('Model:', modname)
        print('Channel:', pm_ch)
        print(
            '3-class Test Accuracy of the model on the test images: {}/{}, {:.2f} %'
            .format(correct, total, percent(correct, total)))
        print(
            '2-class Test Accuracy of the model on the test images: {}/{}, {:.2f} %'
            .format(correct_cc_or_bkg, total,
                    percent(correct_cc_or_bkg, total)))
        print(
            'Wrong-sign Test Accuracy of the model on the test images: {}/{}, {:.2f} %'
            .format(ws_correct, ws_total, percent(ws_correct, ws_total)))
Example #25
0
def main():
    """Parse the command line, then train and validate the selected network.

    Uses the module-level ``parser`` and stores the parsed namespace in the
    global ``args``; the global ``best_err1`` tracks the lowest validation
    error seen so far. The function

    * configures TensorBoard logging when ``args.tensorboard`` is set,
    * builds the train / validation datasets and loaders for ``args.dataset``,
    * instantiates the network named by ``args.model`` (and ``args.depth``),
    * optionally resumes from the checkpoint ``args.resume``,
    * runs the train / validate loop from ``args.start_epoch`` to
      ``args.epochs`` and saves a checkpoint after every epoch.

    Raises:
        Exception: if the dataset, the model name or the requested depth is
            not supported.
    """
    global args, best_err1
    args = parser.parse_args()

    # TensorBoard configure
    if args.tensorboard:
        configure('%s_checkpoints/%s'%(args.dataset, args.expname))

    # CUDA: the visibility variables are set before the first CUDA query.
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_ids)
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        cudnn.benchmark = True  # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        loader_kwargs = {'num_workers': 2, 'pin_memory': True}
    else:
        loader_kwargs = {'num_workers': 2}

    # Per-dataset channel statistics (mean, std) used for normalisation.
    norm_stats = {
        'cifar10': ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        'cifar100': ([0.5071, 0.4865, 0.4409], [0.2634, 0.2528, 0.2719]),
        'cub': ([0.4862, 0.4973, 0.4293], [0.2230, 0.2185, 0.2472]),
        'webvision': ([0.49274242, 0.46481857, 0.41779366],
                      [0.26831809, 0.26145372, 0.27042758]),
    }
    if args.dataset not in norm_stats:
        raise Exception('Unknown dataset: {}'.format(args.dataset))
    mean, std = norm_stats[args.dataset]
    normalize = transforms.Normalize(mean=mean, std=std)

    # Transforms: the horizontal flip is the only step that depends on
    # args.augment.
    train_steps = [transforms.RandomResizedCrop(args.train_image_size)]
    if args.augment:
        train_steps.append(transforms.RandomHorizontalFlip())
    train_steps.extend([transforms.ToTensor(), normalize])
    train_transform = transforms.Compose(train_steps)
    val_transform = transforms.Compose([
        transforms.Resize(args.test_image_size),
        transforms.CenterCrop(args.test_crop_image_size),
        transforms.ToTensor(),
        normalize
    ])

    # Datasets
    if args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10('./data/', train=True, download=True, transform=train_transform)
        val_dataset = datasets.CIFAR10('./data/', train=False, download=True, transform=val_transform)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100('./data/', train=True, download=True, transform=train_transform)
        val_dataset = datasets.CIFAR100('./data/', train=False, download=True, transform=val_transform)
        num_classes = 100
    elif args.dataset == 'cub':
        train_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/train/',
                                             transform=train_transform)
        val_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/DuAngAng/datasets/CUB-200-2011/test/',
                                           transform=val_transform)
        num_classes = 200
    elif args.dataset == 'webvision':
        train_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/train',
                                             transform=train_transform)
        val_dataset = datasets.ImageFolder('/media/ouc/30bd7817-d3a1-4e83-b7d9-5c0e373ae434/LiuJing/WebVision/info/val',
                                           transform=val_transform)
        num_classes = 1000
    else:
        # Guards against a dataset that has statistics above but no loader here.
        raise Exception('Unknown dataset: {}'.format(args.dataset))

    # Data Loader (validation runs one image at a time)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchsize, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, **loader_kwargs)

    # Create model: depth-dependent families use a depth -> constructor table.
    if args.model == 'AlexNet':
        model = alexnet(pretrained=False, num_classes=num_classes)
    elif args.model == 'VGG':
        use_batch_normalization = True  # default use Batch Normalization
        if use_batch_normalization:
            vgg_builders = {11: vgg11_bn, 13: vgg13_bn, 16: vgg16_bn, 19: vgg19_bn}
        else:
            vgg_builders = {11: vgg11, 13: vgg13, 16: vgg16, 19: vgg19}
        if args.depth not in vgg_builders:
            raise Exception('Unsupport VGG detph: {}, optional depths: 11, 13, 16 or 19'.format(args.depth))
        model = vgg_builders[args.depth](pretrained=False, num_classes=num_classes)
    elif args.model == 'Inception':
        model = inception_v3(pretrained=False, num_classes=num_classes)
    elif args.model == 'ResNet':
        resnet_builders = {18: resnet18, 34: resnet34, 50: resnet50,
                           101: resnet101, 152: resnet152}
        if args.depth not in resnet_builders:
            raise Exception('Unsupport ResNet detph: {}, optional depths: 18, 34, 50, 101 or 152'.format(args.depth))
        model = resnet_builders[args.depth](pretrained=False, num_classes=num_classes)
    elif args.model == 'MPN-COV-ResNet':
        mpn_cov_builders = {18: mpn_cov_resnet18, 34: mpn_cov_resnet34,
                            50: mpn_cov_resnet50, 101: mpn_cov_resnet101,
                            152: mpn_cov_resnet152}
        if args.depth not in mpn_cov_builders:
            raise Exception('Unsupport MPN-COV-ResNet detph: {}, optional depths: 18, 34, 50, 101 or 152'.format(args.depth))
        model = mpn_cov_builders[args.depth](pretrained=False, num_classes=num_classes)
    else:
        # The original message had no placeholder, so the offending model
        # name was silently dropped from the error.
        raise Exception('Unsupport model: {}'.format(args.model))

    # Get the number of model parameters
    print('Number of model parameters: {}'.format(sum(p.data.nelement() for p in model.parameters())))

    if use_cuda:
        model = model.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            # Without CUDA, remap GPU-saved tensors to the CPU so the
            # checkpoint can still be restored.
            checkpoint = torch.load(args.resume,
                                    map_location=None if use_cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            best_err1 = checkpoint['best_err1']
            model.load_state_dict(checkpoint['state_dict'])
            print("==> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    print(model)

    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        criterion = criterion.cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # Train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # Evaluate on validation set
        err1 = validate(val_loader, model, criterion, epoch)

        # Remember best err1 and save checkpoint; 'epoch' stores the next
        # epoch to run so that a resume continues after this one.
        is_best = (err1 <= best_err1)
        best_err1 = min(err1, best_err1)
        print("Current best accuracy (error):", best_err1)
        save_checkpoint({
            'epoch': epoch+1,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
        }, is_best)

    print("Best accuracy (error):", best_err1)
Example #26
0
                        help='how many frames to sampler per video')

    parser.add_argument("--video_path",
                        dest='video_path',
                        type=str,
                        default='data/train-video',
                        help='path to video dataset')
    parser.add_argument("--model",
                        dest="model",
                        type=str,
                        default='resnet152',
                        help='the CNN model you want to use to extract_feats')
    parser.add_argument('--dim_vid',
                        dest='dim_vid',
                        type=int,
                        default=2048,
                        help='dim of frames images extracted by cnn model')

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    params = vars(args)
    if params['model'] == 'inception_v3':
        model = inception_v3(pretrained=True).cuda()
    elif params['model'] == 'resnet152':
        model = torchvision.models.resnet152(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1])
        model = model.cuda()
    else:
        print("doesn't support %s" % (params['model']))
    extract_feats(params, model)
def get_model_dics(device, model_list=None):
    """Build the requested models and load their stored weights.

    Args:
        device: device handed to ``load_model`` and, for the pickled
            ``old_*`` models, to ``.to(device)``.
        model_list: iterable of model names to load. ``None`` selects the
            built-in default list below.

    Returns:
        dict mapping each recognised name from ``model_list`` to its loaded
        model, in the order the names were given.

    Names that match no known model are skipped with a warning instead of
    being dropped silently.
    """
    import warnings

    if model_list is None:
        # NOTE(review): 'inception_v4' matches no entry in the tables below
        # (the InceptionV4 checkpoint is registered as 'incept_v4'), so it is
        # skipped. Left unchanged because fixing the name would change which
        # models the default call returns -- confirm the intent.
        model_list = ['densenet121', 'densenet161', 'resnet50', 'resnet152',
                      'incept_v1', 'incept_v3', 'inception_v4', 'incept_resnet_v2',
                      'incept_v4_adv2', 'incept_resnet_v2_adv2',
                      'black_densenet161','black_resnet50','black_incept_v3',
                      'old_vgg','old_res','old_incept']

    # name -> (factory, weight file). Factories are lambdas so a constructor
    # is only looked up and called when its model is actually requested.
    weighted = {
        'densenet121': (lambda: densenet121(num_classes=110),
                        "../pre_weights/ep_38_densenet121_val_acc_0.6527.pth"),
        'densenet161': (lambda: densenet161(num_classes=110),
                        "../pre_weights/ep_30_densenet161_val_acc_0.6990.pth"),
        'resnet50': (lambda: resnet50(num_classes=110),
                     "../pre_weights/ep_41_resnet50_val_acc_0.6900.pth"),
        'incept_v3': (lambda: inception_v3(num_classes=110),
                      "../pre_weights/ep_36_inception_v3_val_acc_0.6668.pth"),
        'incept_v1': (lambda: googlenet(num_classes=110),
                      "../pre_weights/ep_33_googlenet_val_acc_0.7091.pth"),
        'incept_resnet_v2': (lambda: InceptionResNetV2(num_classes=110),
                             "../pre_weights/ep_17_InceptionResNetV2_ori_0.8320.pth"),
        'incept_v4': (lambda: InceptionV4(num_classes=110),
                      "../pre_weights/ep_17_InceptionV4_ori_0.8171.pth"),
        'incept_resnet_v2_adv': (lambda: InceptionResNetV2(num_classes=110),
                                 "../pre_weights/ep_22_InceptionResNetV2_val_acc_0.8214.pth"),
        'incept_v4_adv': (lambda: InceptionV4(num_classes=110),
                          "../pre_weights/ep_24_InceptionV4_val_acc_0.6765.pth"),
        'incept_resnet_v2_adv2': (lambda: InceptionResNetV2(num_classes=110),
                                  "../test_weights/ep_13_InceptionResNetV2_val_acc_0.8889.pth"),
        'incept_v4_adv2': (lambda: InceptionV4(num_classes=110),
                           "../test_weights/ep_50_InceptionV4_val_acc_0.8295.pth"),
        'resnet152': (lambda: resnet152(num_classes=110),
                      "../pre_weights/ep_14_resnet152_ori_0.6956.pth"),
        'resnet152_adv': (lambda: resnet152(num_classes=110),
                          "../pre_weights/ep_29_resnet152_adv_0.6939.pth"),
        'resnet152_adv2': (lambda: resnet152(num_classes=110),
                           "../pre_weights/ep_31_resnet152_adv2_0.6931.pth"),
        'black_resnet50': (lambda: resnet50(num_classes=110),
                           "../test_weights/ep_0_resnet50_val_acc_0.7063.pth"),
        'black_densenet161': (lambda: densenet161(num_classes=110),
                              "../test_weights/ep_4_densenet161_val_acc_0.6892.pth"),
        'black_incept_v3': (lambda: inception_v3(num_classes=110),
                            "../test_weights/ep_28_inception_v3_val_acc_0.6680.pth"),
    }

    # name -> (python source registered as module 'MainModel', pickled model).
    pickled = {
        'old_res': ("./models_old/tf_to_pytorch_resnet_v1_50.py",
                    './models_old/tf_to_pytorch_resnet_v1_50.pth'),
        'old_vgg': ("./models_old/tf_to_pytorch_vgg16.py",
                    './models_old/tf_to_pytorch_vgg16.pth'),
        'old_incept': ("./models_old/tf_to_pytorch_inception_v1.py",
                       './models_old/tf_to_pytorch_inception_v1.pth'),
    }

    models = {}
    for name in model_list:
        if name in weighted:
            build, weights_path = weighted[name]
            # Register the model before loading, as the original code did.
            models[name] = build()
            load_model(models[name], weights_path, device)
        elif name in pickled:
            source_path, model_path = pickled[name]
            # Called for its side effect only: it loads the source file as
            # module 'MainModel' (presumably so torch.load can resolve the
            # pickled classes -- confirm).
            imp.load_source('MainModel', source_path)
            # NOTE(review): torch.load unpickles arbitrary objects; only use
            # it on trusted local files such as these.
            models[name] = torch.load(model_path).to(device)
        else:
            warnings.warn(
                'get_model_dics: skipping unknown model name {!r}'.format(name))

    return models
    y_test = keras.utils.to_categorical(y_test, nb_classes)

    if args.model == 'resnet':
        model = resnet.resnet(nb_classes,
                              depth=args.depth,
                              wide_factor=args.wide_factor)
        save_name = 'resnet_{}_{}_{}'.format(args.depth, args.wide_factor,
                                             args.data)
    elif args.model == 'densenet':
        model = densenet.densenet(nb_classes,
                                  args.growth_rate,
                                  depth=args.depth)
        save_name = 'densenet_{}_{}_{}'.format(args.depth, args.growth_rate,
                                               args.data)
    elif args.model == 'inception':
        model = inception.inception_v3(nb_classes)
        save_name = 'inception_{}'.format(args.data)

    elif args.model == 'vgg':
        model = vggnet.vgg(nb_classes)
        save_name = 'vgg_{}'.format(args.data)
    else:
        raise ValueError('Does not support {}'.format(args.model))

    model.summary()

    weight_file = './results/' + save_name + '.h5'
    # decode
    weights = decode_weights(weight_file)
    for i in range(len(model.layers)):
        if model.layers[i].name in weights: