Exemple #1
0
def get_net_model(net='alexnet',
                  pretrained_dataset='imagenet',
                  dropout=False,
                  pretrained=True):
    """Create a student model together with a frozen teacher model.

    Args:
        net: architecture key -- ``'alexnet'``, ``'mobilenet-imagenet'`` or
            ``'erfnet-cityscapes'``.
        pretrained_dataset: name of the dataset the pretrained weights come
            from. The student starts from pretrained weights only when this
            matches the architecture's dataset AND ``pretrained`` is true.
            The AlexNet/ERFNet teachers only require the dataset to match.
            The MobileNet teacher is loaded whenever the weight file exists.
        dropout: forwarded to the student constructor.
        pretrained: whether the student may start from pretrained weights.

    Returns:
        ``(model, teacher_model)``. Every teacher parameter has
        ``requires_grad`` set to ``False`` and the teacher is in eval mode.

    Raises:
        NotImplementedError: ``net`` is not a supported key.
    """
    if net == 'alexnet':
        from_imagenet = pretrained_dataset == 'imagenet'
        model = myalexnet(pretrained=from_imagenet and pretrained,
                          dropout=dropout)
        teacher_model = alexnet(pretrained=from_imagenet)
    elif net == 'mobilenet-imagenet':
        model = MobileNet(num_classes=1001, dropout=dropout)
        if pretrained_dataset == 'imagenet' and pretrained:
            model.load_state_dict(torch.load(imagenet_pretrained_mbnet_path))
        teacher_model = MobileNet(num_classes=1001)
        if not os.path.isfile(imagenet_pretrained_mbnet_path):
            # Missing weights are tolerated: the teacher stays randomly initialised.
            warnings.warn('failed to import teacher model!')
        else:
            teacher_model.load_state_dict(
                torch.load(imagenet_pretrained_mbnet_path))
    elif net == 'erfnet-cityscapes':
        from_cityscapes = pretrained_dataset == 'cityscapes'
        model = erfnet(pretrained=from_cityscapes and pretrained,
                       num_classes=20,
                       dropout=dropout)
        teacher_model = erfnet(pretrained=from_cityscapes, num_classes=20)
    else:
        raise NotImplementedError

    # Freeze the teacher: no gradients, and eval() switches layers to inference mode.
    for param in teacher_model.parameters():
        param.requires_grad = False
    teacher_model.eval()

    return model, teacher_model
Exemple #2
0
def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
    if model == 'mobilenet' and dataset == 'imagenet':
        from mobilenet import MobileNet
        net = MobileNet(n_class=1000)
    elif model == 'mobilenetv2' and dataset == 'imagenet':
        from mobilenet_v2 import MobileNetV2
        net = MobileNetV2(n_class=1000)
    elif model == 'mobilenet' and dataset == 'cifar10':
        from mobilenet import MobileNet
        net = MobileNet(n_class=10)
    elif model == 'mobilenetv2' and dataset == 'cifar10':
        from mobilenet_v2 import MobileNetV2
        net = MobileNetV2(n_class=10)
    else:
        raise NotImplementedError
    if checkpoint_path:
        print('loading {}...'.format(checkpoint_path))
        sd = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        if 'state_dict' in sd:  # a checkpoint but not a state_dict
            sd = sd['state_dict']
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        net.load_state_dict(sd)

    if torch.cuda.is_available() and n_gpu > 0:
        net = net.cuda()
        if n_gpu > 1:
            net = torch.nn.DataParallel(net, range(n_gpu))

    return net
Exemple #3
0
def net_select(name, data_format='NCHW', weight_decay=0.0005):
    """Instantiate the network architecture called ``name``.

    Each architecture module is imported lazily, so only the selected one
    needs to be importable.

    Args:
        name: one of ``'ResNeXt'``, ``'SENet'``, ``'MobileNet'``,
            ``'ShuffleNet'`` or ``'SphereFace'``.
        data_format: tensor layout forwarded to every architecture except
            SphereFace.
        weight_decay: weight-decay coefficient forwarded to every
            architecture except SphereFace.

    Returns:
        The constructed network object.

    Raises:
        ValueError: ``name`` is not a supported architecture.
    """
    if name == 'ResNeXt':
        from resnext import ResNeXt
        network = ResNeXt(num_layers=50,
                          num_card=1,
                          data_format=data_format,
                          weight_decay=weight_decay)
    elif name == 'SENet':
        from senet import SENet
        network = SENet(num_layers=50,
                        num_card=1,
                        data_format=data_format,
                        weight_decay=weight_decay)
    elif name == 'MobileNet':
        from mobilenet import MobileNet
        network = MobileNet(alpha=1.0,
                            data_format=data_format,
                            weight_decay=weight_decay)
    elif name == 'ShuffleNet':
        from shufflenet import ShuffleNet
        # Construct the ShuffleNet that was just imported (3 groups).
        network = ShuffleNet(num_groups=3,
                             alpha=1.0,
                             data_format=data_format,
                             weight_decay=weight_decay)
    elif name == 'SphereFace':
        from sphere import SphereFace
        network = SphereFace()
    else:
        raise ValueError('Unsupport network architecture.')

    return network
Exemple #4
0
def main():
    """Restore pretrained MobileNet weights, initialise the added layers and save a checkpoint.

    Variables under ``MobileNet`` are restored from
    ``FLAGS.pretrained_mobilenet``, except the newly added ``conv_ds_15a-d``,
    ``conv_ds_16`` and ``conv_ds_17`` layers, which are freshly initialised.
    All global variables are then written with ``save`` to
    ``FLAGS.save_model``.
    """
    tf.set_random_seed(1234)

    # A constant dummy input is enough: only the graph's variables are needed.
    dummy_input = tf.constant(0, tf.float32, shape=[1, 713, 713, 3])
    backbone = MobileNet(dummy_input, print_architecture=True)

    # Scopes of layers that have no pretrained counterpart.
    added_scopes = ['MobileNet/conv_ds_15' + suffix for suffix in ('a', 'b', 'c', 'd')]
    added_scopes += ['MobileNet/conv_ds_16', 'MobileNet/conv_ds_17']

    pretrained_vars = slim.get_variables_to_restore(include=['MobileNet'],
                                                    exclude=added_scopes)
    added_vars = slim.get_variables_to_restore(include=added_scopes)
    init_added = tf.variables_initializer(added_vars)

    every_var = tf.global_variables()

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        restorer = tf.train.Saver(var_list=pretrained_vars)
        restorer.restore(sess, FLAGS.pretrained_mobilenet)

        sess.run(init_added)

        # Write every variable (restored + fresh) out as one .ckpt.
        writer = tf.train.Saver(var_list=every_var)
        save(writer, sess, FLAGS.save_model)
Exemple #5
0
def objective(space):
    """Train a MobileNet built from one hyper-parameter sample and return its mean test loss.

    Args:
        space: 7-element sequence ``(block_kernel1, block_stride1,
            block_kernel2, block_stride2, kernel_size1, stride1,
            learning_rate)``. The first six values are cast to ``int`` and
            the learning rate to ``float`` (presumably because the search
            space samples floats -- TODO confirm).

    Returns:
        The values returned by ``test`` over epochs
        ``start_epoch .. start_epoch + args.num_epochs - 1``, summed and
        divided by ``args.num_epochs``.
    """
    (block_kernel1, block_stride1, block_kernel2, block_stride2,
     kernel_size1, stride1, learning_rate) = space

    block_kernel1 = int(block_kernel1)
    block_stride1 = int(block_stride1)
    block_kernel2 = int(block_kernel2)
    block_stride2 = int(block_stride2)
    kernel_size1 = int(kernel_size1)
    stride1 = int(stride1)
    learning_rate = float(learning_rate)

    block = Block(in_planes=64, out_planes=64,
                  block_kernel1=block_kernel1, block_stride1=block_stride1,
                  block_kernel2=block_kernel2, block_stride2=block_stride2)
    net = MobileNet(block, kernel_size1=kernel_size1, stride1=stride1)

    if use_cuda:
        net.cuda()
        # Replicate across every visible GPU.
        net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
        cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

    total_loss = 0.0
    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        train(epoch, net, optimizer, criterion)
        total_loss += test(epoch, net)

    return total_loss / args.num_epochs
Exemple #6
0
def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
    if dataset == 'imagenet':
        n_class = 1000
    elif dataset == 'cifar10':
        n_class = 10
    else:
        raise ValueError('unsupported dataset')

    if model == 'mobilenet':
        from mobilenet import MobileNet
        net = MobileNet(n_class=n_class)
    elif model == 'mobilenetv2':
        from mobilenet_v2 import MobileNetV2
        net = MobileNetV2(n_class=n_class)
    elif model.startswith('resnet'):
        net = resnet.__dict__[model](pretrained=True)
        in_features = net.fc.in_features
        net.fc = nn.Linear(in_features, n_class)
    else:
        raise NotImplementedError
    if checkpoint_path:
        print('loading {}...'.format(checkpoint_path))
        sd = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        if 'state_dict' in sd:  # a checkpoint but not a state_dict
            sd = sd['state_dict']
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        net.load_state_dict(sd)

    if torch.cuda.is_available() and n_gpu > 0:
        net = net.cuda()
        if n_gpu > 1:
            net = torch.nn.DataParallel(net, range(n_gpu))

    return net
Exemple #7
0
def training_model(model_name='mobilenet'):
    """Train a MobileNet-family Keras model from data generators with MSE loss.

    Args:
        model_name: ``'mobilenet'`` (MobileNet, alpha=1.0),
            ``'mobilenet_dih'`` (MobileNetDih4, alpha=1) or
            ``'mobilenet_dih_r'`` (MobileNetDR, alpha=0.5). Each is built
            with the ``tconfig`` returned by ``get_gen_tconfig()``.

    NOTE(review): any other ``model_name`` leaves ``model`` unbound, so the
    function fails with UnboundLocalError at ``model.compile``.
    NOTE(review): the callbacks name 'mobilenet05_short_adam03_dr35_v3' is
    the same for every ``model_name``.
    """
    # Training/validation generators plus the config object passed to the model.
    train_gen, valid_gen, tconfig = get_gen_tconfig()
    callbacks = get_callbacks('mobilenet05_short_adam03_dr35_v3', patience=4)
    if model_name == 'mobilenet':
        print('MobileNet')
        model = MobileNet(config=tconfig, alpha=1.0)
        model.summary()
    elif model_name == 'mobilenet_dih':
        print('MobileNetDih')
        model = MobileNetDih4(config=tconfig, alpha=1)
        model.summary()
    elif model_name == 'mobilenet_dih_r':
        print('MobileNetDihR')
        model = MobileNetDR(config=tconfig, alpha=0.5)
        model.summary()
    opt = Adam(lr=1e-3, beta_1=0.9, beta_2=0.999)
    #opt = Adadelta(lr=1e-1, rho=0.95, decay=0.1)
    #opt = SGD(lr=1e-7, momentum=0.9, decay=0., nesterov=True)

    # MSE loss, reporting MAE and MSE as metrics.
    model.compile(optimizer=opt, loss='mse', metrics=['mae', 'mse'])
    #model.load_weights('mobilenet_05shortd01_catcros_resize_b16.hdf5')
    # 40 epochs of 1000 steps; 500 validation steps per epoch.
    model.fit_generator(generator=train_gen,
                        steps_per_epoch=1000,
                        epochs=40,
                        validation_data=valid_gen,
                        verbose=2,
                        validation_steps=500,
                        callbacks=callbacks)
    #opt = Adam(lr=1e-3, beta_1=0.9, beta_2=0.999)
    #opt = Adadelta(lr=1e-1, rho=0.95, decay=0.1)
    # NOTE(review): the triple-quoted string opened below is not closed within
    # this excerpt; it appears to disable the code that follows -- confirm.
    """
Exemple #8
0
    def quan(self, config_file):
        """Build a MobileNet classifier on 1x28x28 inputs and run the Compressor on it.

        Returns immediately, doing nothing, when Paddle is not compiled with
        CUDA (the executor below is pinned to ``CUDAPlace(0)``).

        Args:
            config_file: compression config passed to ``Compressor.config``.

        NOTE(review): the graph returned by ``com_pass.run()`` is stored in
        ``eval_graph`` but not used in the visible code.
        """
        if not fluid.core.is_compiled_with_cuda():
            return
        class_dim = 10  # number of output classes
        image_shape = [1, 28, 28]  # per-sample input shape (C, H, W)

        train_program = fluid.Program()
        startup_program = fluid.Program()
        # NOTE(review): this Program is never used; val_program is reassigned below.
        val_program = fluid.Program()

        # Forward network, metrics and loss are built into train_program;
        # parameter-initialisation ops go into startup_program.
        with fluid.program_guard(train_program, startup_program):
            with fluid.unique_name.guard():
                image = fluid.layers.data(name='image',
                                          shape=image_shape,
                                          dtype='float32')
                image.stop_gradient = False
                label = fluid.layers.data(name='label',
                                          shape=[1],
                                          dtype='int64')
                out = MobileNet(name='quan').net(input=image,
                                                 class_dim=class_dim)
                print("out: {}".format(out.name))
                acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
                acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
                cost = fluid.layers.cross_entropy(input=out, label=label)
                avg_cost = fluid.layers.mean(x=cost)
        # The optimizer is only constructed here (minimize() is not called);
        # it is handed to Compressor as train_optimizer below.
        optimizer = fluid.optimizer.Momentum(
            momentum=0.9,
            learning_rate=0.01,
            regularization=fluid.regularizer.L2Decay(4e-5))

        # NOTE(review): for_test=False keeps the clone's training-mode ops;
        # confirm this is intended for an evaluation program (see
        # Program.clone docs).
        val_program = train_program.clone(for_test=False)

        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        exe.run(startup_program)

        # Readers and feed lists come from helper methods on this class.
        val_reader = self.set_val_reader(image, label, place)

        val_feed_list = self.set_feed_list(image, label)
        val_fetch_list = [('acc_top1', acc_top1.name),
                          ('acc_top5', acc_top5.name)]

        train_reader = self.set_train_reader(image, label, place)
        train_feed_list = self.set_feed_list(image, label)
        train_fetch_list = [('loss', avg_cost.name)]

        com_pass = Compressor(place,
                              fluid.global_scope(),
                              train_program,
                              train_reader=train_reader,
                              train_feed_list=train_feed_list,
                              train_fetch_list=train_fetch_list,
                              eval_program=val_program,
                              eval_reader=val_reader,
                              eval_feed_list=val_feed_list,
                              eval_fetch_list=val_fetch_list,
                              train_optimizer=optimizer)
        com_pass.config(config_file)
        eval_graph = com_pass.run()
Exemple #9
0
def get_model(args):
    """Build the model described by ``args`` and optionally load and prune it.

    Args:
        args: namespace with ``dataset`` (``'imagenet'`` or ``'cifar10'``),
            ``model_type`` (``'mobilenet'``, ``'mobilenetv2'`` or a name
            starting with ``'resnet'``), ``ckpt_path`` (state_dict path or
            ``None``), ``mask_path`` (pruning mask or ``None``; used only
            with a checkpoint), ``device`` and ``n_gpu``.

    Returns:
        The network on ``args.device``. It is wrapped in DataParallel when
        CUDA is available and ``args.n_gpu > 1``.

    Raises:
        NotImplementedError: unsupported dataset or model type.
    """
    print('=> Building model..')

    dataset = args.dataset
    if dataset == 'imagenet':
        n_class = 1000
    elif dataset == 'cifar10':
        n_class = 10
    else:
        raise NotImplementedError

    kind = args.model_type
    if kind == 'mobilenet':
        net = MobileNet(n_class=n_class)
    elif kind == 'mobilenetv2':
        net = MobileNetV2(n_class=n_class)
    elif kind.startswith('resnet'):
        net = resnet.__dict__[kind](pretrained=True)
        # Replace the ImageNet head with an n_class-way classifier.
        net.fc = nn.Linear(net.fc.in_features, n_class)
    else:
        raise NotImplementedError

    ckpt_path = args.ckpt_path
    if ckpt_path is not None:
        # the checkpoint can be state_dict exported by amc_search.py or saved by amc_train.py
        print('=> Loading checkpoint {} ..'.format(ckpt_path))
        cpu = torch.device('cpu')
        net.load_state_dict(torch.load(ckpt_path, cpu))
        mask_path = args.mask_path
        if mask_path is not None:
            # A dummy batch of two images at the dataset's resolution drives the speedup.
            side = 224 if dataset == 'imagenet' else 32
            dummy_batch = torch.randn(2, 3, side, side)
            ModelSpeedup(net, dummy_batch, mask_path, cpu).speedup_model()

    net.to(args.device)
    if torch.cuda.is_available() and args.n_gpu > 1:
        net = torch.nn.DataParallel(net, list(range(args.n_gpu)))
    return net
Exemple #10
0
def create_model(model_type=None,
                 n_classes=120,
                 input_size=224,
                 checkpoint=None,
                 pretrained=False,
                 width_mult=1.):
    """Build a MobileNet-family classifier and optionally load a checkpoint into it.

    Args:
        model_type: ``'mobilenet_v1'``, ``'mobilenet_v2'``,
            ``'mobilenet_v2_torchhub'`` or ``None`` (no model is built).
        n_classes: number of output classes.
        input_size: input resolution, used by ``'mobilenet_v2'`` only.
        checkpoint: optional path to a state_dict file. Tensors are mapped
            to CPU before being copied into the model, so GPU-saved files
            also load on CPU-only hosts.
        pretrained: passed to ``torch.hub.load`` for
            ``'mobilenet_v2_torchhub'``; ignored by the other types.
        width_mult: width multiplier, used by ``'mobilenet_v2'`` only.

    Returns:
        The model, or ``None`` when ``model_type`` is ``None``.

    Raises:
        RuntimeError: unknown ``model_type``.
        ValueError: ``checkpoint`` given while ``model_type`` is ``None``.
    """
    if model_type == 'mobilenet_v1':
        model = MobileNet(n_class=n_classes, profile='normal')
    elif model_type == 'mobilenet_v2':
        model = MobileNetV2(n_class=n_classes,
                            input_size=input_size,
                            width_mult=width_mult)
    elif model_type == 'mobilenet_v2_torchhub':
        model = torch.hub.load('pytorch/vision:v0.8.1',
                               'mobilenet_v2',
                               pretrained=pretrained)
        # Replace the final Linear layer of the classifier with an n_classes-way one.
        feature_size = model.classifier[1].in_features
        model.classifier[1] = torch.nn.Linear(feature_size, n_classes)
    elif model_type is None:
        model = None
    else:
        raise RuntimeError('Unknown model_type.')

    if checkpoint is not None:
        if model is None:
            raise ValueError(
                'checkpoint {!r} given but model_type is None'.format(checkpoint))
        model.load_state_dict(torch.load(checkpoint, map_location='cpu'))

    return model
Exemple #11
0
def load_model(args):
    """Instantiate an ImageNet-pretrained Keras classifier selected by ``args.model``.

    Args:
        args: object whose ``model`` attribute is one of ``'inception'``,
            ``'xception'``, ``'inceptionresnet'``, ``'mobilenet'``,
            ``'mobilenet2'``, ``'nasnet'``, ``'resnet'``, ``'vgg16'`` or
            ``'vgg19'``.

    Returns:
        ``(model, preprocess_mode)``, where ``preprocess_mode`` is ``'tf'``
        or ``'caffe'`` depending on the model (presumably the ``mode``
        argument for Keras input preprocessing -- TODO confirm).

    Raises:
        ValueError: ``args.model`` is not a known model name.
    """
    # name -> (builder, preprocess_mode). Builders are lambdas so a model
    # class is only looked up, and its weights only fetched, when selected.
    builders = {
        'inception': (lambda: InceptionV3(include_top=True, weights='imagenet'), 'tf'),
        'xception': (lambda: Xception(include_top=True, weights='imagenet'), 'tf'),
        'inceptionresnet': (lambda: InceptionResNetV2(include_top=True, weights='imagenet'), 'tf'),
        'mobilenet': (lambda: MobileNet(include_top=True, weights='imagenet'), 'tf'),
        'mobilenet2': (lambda: MobileNetV2(include_top=True, weights='imagenet'), 'tf'),
        'nasnet': (lambda: NASNetLarge(include_top=True, weights='imagenet'), 'tf'),
        'resnet': (lambda: ResNet50(include_top=True, weights='imagenet'), 'caffe'),
        'vgg16': (lambda: VGG16(include_top=True, weights='imagenet'), 'caffe'),
        'vgg19': (lambda: VGG19(include_top=True, weights='imagenet'), 'caffe'),
    }
    try:
        build, preprocess_mode = builders[args.model]
    except KeyError:
        # Fail loudly: returning here would hit unbound locals.
        raise ValueError('Model not found: {!r}'.format(args.model)) from None

    return build(), preprocess_mode
Exemple #12
0
def get_model(args):
    """Build a MobileNet / MobileNetV2 for ``args`` and optionally restore a checkpoint.

    Args:
        args: namespace with ``dataset`` (``'imagenet'`` or ``'cifar10'``),
            ``model_type`` (``'mobilenet'`` or ``'mobilenetv2'``),
            ``ckpt_path`` (path or ``None``), ``device`` and ``n_gpu``.

    Returns:
        The network on ``args.device``. It is wrapped in DataParallel when
        CUDA is available and ``args.n_gpu > 1``.

    Raises:
        NotImplementedError: unsupported dataset or model type.
    """
    print('=> Building model..')

    if args.dataset == 'imagenet':
        n_class = 1000
    elif args.dataset == 'cifar10':
        n_class = 10
    else:
        raise NotImplementedError

    # Built on CPU: placement is done by net.to(args.device) below, so an
    # unconditional .cuda() here would only break CPU-only runs.
    if args.model_type == 'mobilenet':
        net = MobileNet(n_class=n_class)
    elif args.model_type == 'mobilenetv2':
        net = MobileNetV2(n_class=n_class)
    else:
        raise NotImplementedError

    if args.ckpt_path is not None:
        # The checkpoint can be a whole model object exported by
        # amc_search.py, a state_dict, or a dict holding one under 'state_dict'.
        print('=> Loading checkpoint {} ..'.format(args.ckpt_path))
        ckpt = torch.load(args.ckpt_path, map_location=torch.device('cpu'))
        # isinstance: state_dict() returns an OrderedDict, which is a dict
        # subclass that an exact type(...) == dict check would miss.
        if isinstance(ckpt, dict):
            net.load_state_dict(ckpt.get('state_dict', ckpt))
        else:
            net = ckpt

    net.to(args.device)
    if torch.cuda.is_available() and args.n_gpu > 1:
        net = torch.nn.DataParallel(net, list(range(args.n_gpu)))
    return net
Exemple #13
0
    def test_compression(self):
        """Run the auto-pruning Compressor on a MobileNet classifier for MNIST.

        Model: mobilenet_v1
        data: mnist
        step1: Training one epoch
        step2: pruning flops
        step3: fine-tune one epoch
        step4: check top1_acc.

        This method only builds the programs, readers and feed/fetch lists.
        The steps above are presumably driven by
        './auto_pruning/compress.yaml' (resolved relative to the working
        directory) inside ``Compressor`` -- TODO confirm. It returns
        immediately when Paddle is not compiled with CUDA.

        NOTE(review): nothing in this method asserts on the result
        (``eval_graph`` is unused); confirm where top1_acc is checked.
        """
        if not fluid.core.is_compiled_with_cuda():
            return
        class_dim = 10  # MNIST digit classes
        image_shape = [1, 28, 28]  # per-sample input shape (C, H, W)
        # Everything below is built into fluid's default main/startup programs.
        image = fluid.layers.data(name='image',
                                  shape=image_shape,
                                  dtype='float32')
        image.stop_gradient = False
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        out = MobileNet("auto_pruning").net(input=image, class_dim=class_dim)
        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
        acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
        # Cloned after the accuracy ops but before the loss ops are added, so
        # the eval program has no cost layers.
        # NOTE(review): for_test=False keeps training-mode ops in the clone --
        # confirm this is intended for evaluation.
        val_program = fluid.default_main_program().clone(for_test=False)

        cost = fluid.layers.cross_entropy(input=out, label=label)
        avg_cost = fluid.layers.mean(x=cost)

        # Constructed only; Compressor receives it as train_optimizer.
        optimizer = fluid.optimizer.Momentum(
            momentum=0.9,
            learning_rate=0.01,
            regularization=fluid.regularizer.L2Decay(4e-5))

        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)

        # Feed lists pair a label ('img' / 'label') with the data-layer name.
        val_feed_list = [('img', image.name), ('label', label.name)]
        val_fetch_list = [('acc_top1', acc_top1.name),
                          ('acc_top5', acc_top5.name)]

        train_reader = paddle.batch(paddle.dataset.mnist.train(),
                                    batch_size=128)
        train_feed_list = [('img', image.name), ('label', label.name)]
        train_fetch_list = [('loss', avg_cost.name)]

        com_pass = Compressor(place,
                              fluid.global_scope(),
                              fluid.default_main_program(),
                              train_reader=train_reader,
                              train_feed_list=train_feed_list,
                              train_fetch_list=train_fetch_list,
                              eval_program=val_program,
                              eval_reader=val_reader,
                              eval_feed_list=val_feed_list,
                              eval_fetch_list=val_fetch_list,
                              train_optimizer=optimizer)
        com_pass.config('./auto_pruning/compress.yaml')
        eval_graph = com_pass.run()
Exemple #14
0
 def __init__(self, opt):
     """Assemble the keypoint network from a backbone, a transform and six stages.

     The parts are a MobileNet backbone, a two-layer conv-BN-ReLU
     transform and a ModuleList of six ``Stage`` modules.

     Args:
         opt: options object, forwarded unchanged to the parent constructor.
     """
     super(KeypointModel, self).__init__(opt)
     self.pretrained = MobileNet()

     # 3x3 convs (stride 1, padding 1) taking 256 -> 256 -> 128 channels,
     # each followed by BatchNorm and an in-place ReLU.
     transform_layers = [
         nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),
         nn.Conv2d(256, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(True),
     ]
     self.trf = nn.Sequential(*transform_layers)

     # One Stage(128) followed by five Stage(169); the argument is presumably
     # the stage's input channel count -- TODO confirm.
     stage_modules = [Stage(128)]
     stage_modules.extend(Stage(169) for _ in range(5))
     self.stages = nn.ModuleList(stage_modules)
Exemple #15
0
def training_model(model_name='mobilenet'):
    """Train a MobileNet-variant binary classifier in two phases.

    Phase 1 trains with Adam (lr 1e-3) for 40 epochs using random rotation,
    zoom, shift and flip augmentation. Phase 2 reloads weights from
    ``'mobilenet_10shortd01_b16_sgd'``, recompiles with SGD (lr 0.05,
    Nesterov momentum 0.9) and trains 10 more epochs. Both phases use
    ``16 * len(train_y) // BATCH_SIZE`` steps per epoch.

    Args:
        model_name: ``'mobilenet'`` (MobileNet), ``'mobilenet_dih'``
            (MobileNetDih4) or ``'mobilenet_dih_r'`` (MobileNetDR), all
            with ``alpha=1.``.

    Raises:
        ValueError: unknown ``model_name``; checked before any data is loaded.
    """
    # name -> (label printed before building, builder). Lambdas defer the
    # class lookup until the chosen model is built.
    builders = {
        'mobilenet': ('MobileNet', lambda: MobileNet(alpha=1.)),
        'mobilenet_dih': ('MobileNetDih', lambda: MobileNetDih4(alpha=1.)),
        'mobilenet_dih_r': ('MobileNetDihR', lambda: MobileNetDR(alpha=1.)),
    }
    if model_name not in builders:
        raise ValueError('unknown model_name: {!r}'.format(model_name))
    label, build = builders[model_name]

    train_img, valid_img, train_y, valid_y = get_data()
    callbacks = get_callbacks('mobilenet_10fulld01_b16', patience=2)
    print(label)
    model = build()
    model.summary()

    model.compile(optimizer=Adam(lr=1e-3, beta_1=0.9, beta_2=0.999),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    gen = ImageDataGenerator(rotation_range=359,
                             zoom_range=[0.5, 2],
                             width_shift_range=0.1,
                             height_shift_range=0.1,
                             vertical_flip=True,
                             horizontal_flip=True)

    # Convert to arrays once; both phases reuse them.
    x_train, y_train = np.array(train_img), np.array(train_y)
    validation_data = [np.array(valid_img), np.array(valid_y)]
    steps_per_epoch = 16 * len(train_y) // BATCH_SIZE

    def fit(epochs):
        """Run one augmented training phase of ``epochs`` epochs."""
        model.fit_generator(
            gen.flow(x_train, y_train, batch_size=BATCH_SIZE),
            steps_per_epoch=steps_per_epoch,
            epochs=epochs,
            validation_data=validation_data,
            verbose=1,
            callbacks=callbacks)

    fit(40)

    # Phase 2: fine-tune with SGD. The model must be recompiled for the new
    # optimizer to be used at all.
    model.load_weights('mobilenet_10shortd01_b16_sgd')
    model.compile(optimizer=SGD(lr=0.05, momentum=0.9, decay=0., nesterov=True),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    fit(10)
Exemple #16
0
def _get_mobilenet_features(image, mode, load_weights=False, alpha=1):
    """Return the output tensor of a headless MobileNet applied to ``image``.

    Args:
        image: input tensor; its static shape without the leading (batch)
            dimension is used as ``input_shape``.
        mode: estimator mode. The Keras learning phase is set to training
            only when it equals ``tf.estimator.ModeKeys.TRAIN``.
        load_weights: load ImageNet weights when true, otherwise use random
            initialisation.
        alpha: MobileNet width multiplier.
    """
    from mobilenet import MobileNet

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    tf.keras.backend.set_learning_phase(is_training)

    backbone = MobileNet(input_shape=image.shape.as_list()[1:],
                         input_tensor=image,
                         include_top=False,
                         weights=('imagenet' if load_weights else None),
                         alpha=alpha)
    return backbone.output
Exemple #17
0
def make(tflite=False):
    """Create the model.

    Both variants take 224x224x3 input, use width multiplier 0.5, start
    without pretrained weights and have 101 output classes.

    Args:
        tflite: if true, build the MobileNet variant modified for TensorFlow
            Lite; otherwise build TensorFlow's standard
            ``tf.keras.applications.MobileNet``.
    """
    # Same hyper-parameters for both variants.
    options = dict(input_shape=(224, 224, 3), alpha=0.5, weights=None, classes=101)
    builder = MobileNet if tflite else tf.keras.applications.MobileNet
    return builder(**options)
Exemple #18
0
def main():
    """Time a fixed list of 17-class models and print one line per model.

    Each model is passed to ``check_speed`` with the entry of ``test_images``
    for its input size (224, 299 or 100). The printed line is the name,
    then the two returned timings multiplied by 1000 with two decimals,
    tab-separated. The timings are presumably single-sample CPU latency and
    batched GPU time in seconds, shown as milliseconds -- TODO confirm.
    """
    num_classes = 17
    # All models are constructed up-front, in this order.
    benchmarks = [
        ('VGGNetBN', VGGNetBN(num_classes), 224),
        ('VGGNetBNHalf', VGGNetBN(num_classes, 32), 224),
        ('VGGNetBNQuater', VGGNetBN(num_classes, 16), 224),
        ('GoogLeNetBN', GoogLeNetBN(num_classes), 224),
        ('GoogLeNetBNHalf', GoogLeNetBN(num_classes, 16), 224),
        ('GoogLeNetBNQuater', GoogLeNetBN(num_classes, 8), 224),
        ('ResNet50', ResNet50(num_classes), 224),
        ('ResNet50Half', ResNet50(num_classes, 32), 224),
        ('ResNet50Quater', ResNet50(num_classes, 16), 224),
        ('SqueezeNet', SqueezeNet(num_classes), 224),
        ('SqueezeNetHalf', SqueezeNet(num_classes, 8), 224),
        ('MobileNet', MobileNet(num_classes), 224),
        ('MobileNetHalf', MobileNet(num_classes, 16), 224),
        ('MobileNetQuater', MobileNet(num_classes, 8), 224),
        ('InceptionV4', InceptionV4(dim_out=num_classes), 299),
        ('InceptionV4S', InceptionV4(dim_out=num_classes, base_filter_num=6,
                                     ablocks=2, bblocks=1, cblocks=1), 299),
        ('InceptionResNetV2', InceptionResNetV2(dim_out=num_classes), 299),
        ('InceptionResNetV2S', InceptionResNetV2(dim_out=num_classes, base_filter_num=8,
                                                 ablocks=1, bblocks=2, cblocks=1), 299),
        ('FaceClassifier100x100V', FaceClassifier100x100V(num_classes), 100),
        ('FaceClassifier100x100V2', FaceClassifier100x100V2(num_classes), 100),
    ]

    for label, network, side in benchmarks:
        cpu_time, gpu_time = check_speed(network, test_images[side])
        print(f'{label}\t{cpu_time * 1000:.02f}\t{gpu_time * 1000:.02f}')
Exemple #19
0
def main(_):
    """Train MobileNet on the dog-image tensors until ``FLAGS.max_steps``.

    The single positional argument is ignored. Checkpointing and summaries
    (every 15 s) are handled by a ``tf.train.Supervisor`` in
    ``FLAGS.logdir``.
    """
    tf.logging.set_verbosity(tf.logging.DEBUG)

    with tf.variable_scope('data'):
        images, labels, n_classes = dog_tensor(
            FLAGS.dogdir, FLAGS.batch_size, class_regex=FLAGS.dog_regex)

    model = MobileNet(n_classes, alpha=FLAGS.alpha)
    logits = model(images, is_training=True)
    tf.logging.info('built model on training data')
    # Report trainable-parameter statistics; the returned object is not used.
    tfprof.model_analyzer.print_model_analysis(
        tf.get_default_graph(),
        tfprof_options=tfprof.model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)

    with tf.variable_scope('training'):
        xent = tf.reduce_mean(
            tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits))
        tf.summary.scalar('train/xent', xent)

        global_step = tf.train.get_or_create_global_step()
        minimize_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
            xent, global_step=global_step)
        # Run the UPDATE_OPS collection (e.g. running averages) with every step.
        pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        train_op = tf.group(minimize_op, *pending_updates) if pending_updates else minimize_op

    supervisor = tf.train.Supervisor(logdir=FLAGS.logdir,
                                     global_step=global_step,
                                     save_summaries_secs=15)

    with supervisor.managed_session() as sess, supervisor.stop_on_exception():
        tf.logging.debug('ready to run things')

        step = sess.run(global_step)
        while step < FLAGS.max_steps:
            step, loss_value, _ = sess.run([global_step, xent, train_op])
            tf.logging.info('(%d) train loss: %f', step, loss_value)
Exemple #20
0
def run_training(config,
                 n_classes,
                 train_loader,
                 valid_loader,
                 width=1,
                 mb_version=1):
    """Train a model with Adam, then fine-tune it with SGD under SGDR.

    The architecture is a torchvision ResNet-18 when ``width > 1``.
    Otherwise it is MobileNet (``mb_version == 1``) or MobileNetV2 with
    width multiplier ``width``. The parameter count is printed before
    training.

    Returns:
        ``[train_loss, train_acc, valid_loss, valid_acc]``, each the
        concatenation (``+``) of the Adam-phase and SGD-phase histories
        returned by ``train``.
    """
    if width > 1:
        model = tvm.resnet18(num_classes=n_classes)
    elif mb_version == 1:
        model = MobileNet(n_classes=n_classes, width_mult=width)
    else:
        model = MobileNetV2(n_classes=n_classes, width_mult=width)
    model = model.to(config['device'])

    num_params = sum(np.prod(p.size()) for p in model.parameters())
    print(f"width={width}, num_params {num_params}")

    criterion = t.nn.CrossEntropyLoss()

    # Phase 1: Adam at config['lr'] with MultiStepLR milestones [3, 6].
    adam = t.optim.Adam(model.parameters(), config['lr'])
    loss_a, acc_a, vloss_a, vacc_a = train(
        config, model, train_loader, valid_loader, criterion, adam,
        t.optim.lr_scheduler.MultiStepLR(adam, [3, 6]))

    # Phase 2: SGD (lr / 10, momentum 0.9) with the SGDR scheduler.
    sgd = t.optim.SGD(model.parameters(), config['lr'] / 10, momentum=0.9)
    loss_b, acc_b, vloss_b, vacc_b = train(
        config, model, train_loader, valid_loader, criterion, sgd,
        SGDR(sgd, 3, 1.2))

    return [loss_a + loss_b, acc_a + acc_b, vloss_a + vloss_b, vacc_a + vacc_b]
Exemple #21
0
def train():
    """Train the frame-stack MobileNet on the UT-Interaction segmented sets 1 and 2.

    Uses RMSprop (lr 1e-3) and categorical cross-entropy: 300 epochs of
    100 steps with batch size 5, augmented by ``augmentation``.
    """
    # Augmentation settings for the training generator. The keys presumably
    # mean rotation / width-shift / height-shift / zoom ranges -- TODO confirm.
    # Other values tried: rg 7 or 5, wrg 1 or 3, hrg 1 or 3, zoom 1.
    augmentation = {
        'flag': True,
        'rg': 25,
        'wrg': 0.25,
        'hrg': 0.25,
        'zoom': 0.25,
    }
    callbacks = get_callbacks('mynet_v4_bias', patience=30)

    # Collect file paths and labels from both sets into shared lists.
    paths, y = search_file('set1/segmented_set1')
    paths, y = search_file('set2/segmented_set2', paths=paths, y=y)

    dataset = DataSet(nframe=30, fstride=6, name='UT interaction',
                      size=[224, 224, 3], filepaths=paths, y=y, kernel_size=4)
    for split in ('train', 'valid'):
        dataset.make_set(op='msqr', name=split)

    optimizer = RMSprop(lr=0.001, rho=0.9, decay=0.01)

    model = MobileNet(alpha=1.0, shape=[29, 56, 56, 1], nframe=29)
    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.summary()

    # NOTE(review): getVlen is passed without being called; if it is a method
    # rather than a property this should be dataset.getVlen() -- confirm.
    model.fit_generator(generator=dataset.train_gen(batch_size=5, aug_config=augmentation),
                        steps_per_epoch=100,
                        epochs=300,
                        validation_data=dataset.valid_gen(),
                        verbose=1,
                        validation_steps=dataset.getVlen,
                        callbacks=callbacks)
Exemple #22
0
def main():
    """Caption one image: MobileNet features are fed to the Generator and decoded.

    The word/id dictionaries are cached in ``DICO_PKL``. They are built from
    ``DICO`` on the first run and loaded from the pickle afterwards. Each
    generated sample is printed as its words joined without a separator.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    if os.path.exists(DICO_PKL):
        # NOTE: pickle is only safe because this file is a cache written by
        # this script; never point DICO_PKL at untrusted data.
        with open(DICO_PKL, 'rb') as f:
            word_to_id, id_to_word = pickle.load(f)
    else:
        word_to_id, id_to_word = create_dico(DICO)
        with open(DICO_PKL, 'wb') as f:
            pickle.dump([word_to_id, id_to_word], f)

    vocab_size = len(word_to_id)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, True)
    mobilenet = MobileNet(BATCH_SIZE)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Initialise first, then load the pretrained weights. Running the
        # global initializer afterwards would overwrite them with fresh values.
        sess.run(tf.global_variables_initializer())
        mobilenet.load_pretrained_weights(sess)

        with Image.open(IMAGE) as img:
            pixels = np.array(img.convert('RGB').resize((224, 224)))
        batch = np.expand_dims(pixels, 0)  # add a leading batch axis
        feed_dict = {mobilenet.X: batch, mobilenet.is_training: False}
        hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
        samples = generator.generate(sess, hidden_batch)

        for sample in samples.tolist():
            print(''.join(id_to_word[i] for i in sample))
Exemple #23
0
def get_network(args, use_gpu=True):
    """Return the network named by ``args.net``, optionally moved to the GPU.

    Args:
        args: object whose ``net`` attribute is one of ``'resnet18'``,
            ``'resnetcbam18'``, ``'resnetzam18'``, ``'mobilenet'``,
            ``'mobilenetcbam'`` or ``'mobilenetzam'``.
        use_gpu: call ``.cuda()`` on the network when true.

    Raises:
        SystemExit: for an unsupported name. The exit status is non-zero and
            the message is written to stderr.
    """
    # Lambdas defer the class lookup so only the selected network is touched.
    builders = {
        'resnet18': lambda: ResNet18(),
        'resnetcbam18': lambda: ResNetCBAM18(),
        'resnetzam18': lambda: ResNetZAM18(),
        'mobilenet': lambda: MobileNet(),
        'mobilenetcbam': lambda: MobileNetCBAM(),
        'mobilenetzam': lambda: MobileNetZAM(),
    }
    try:
        build = builders[args.net]
    except KeyError:
        # sys.exit(message) prints to stderr and exits with status 1; a bare
        # sys.exit() would report success (status 0) for this error.
        sys.exit('the network name you have entered is not supported yet')

    net = build()
    if use_gpu:
        net = net.cuda()

    return net
Exemple #24
0
def train():
    """Train MobileNet on the CIFAR-10 training TFRecord, configured by ``args``.

    Builds the input pipeline, the softmax cross-entropy loss plus any losses
    registered in ``REGULARIZATION_LOSSES``, batch accuracy, and an
    exponentially decayed learning rate for Adam. It resumes from
    ``args.checkpoint_dir`` via ``load``, then trains up to ``max_steps``.

    During training:

    * every ``args.num_log`` steps it writes summaries, prints progress and
      saves a checkpoint;
    * every 100 steps it averages loss/accuracy over 200 further batches.

    At the end it writes the graph definition and a final checkpoint.
    """
    # NOTE: the input file is hard-coded; args.dataset_dir is not used here.
    img_batch, label_batch = get_batch("cifar10/cifar10_train.tfrecord",
                                       args.batch_size,
                                       shuffle=True)

    mobilenet = MobileNet(img_batch, num_classes=args.num_classes)
    logits = mobilenet.logits
    pred = mobilenet.predictions

    # label_batch holds one label distribution (e.g. one-hot) per row. Both
    # the cross-entropy below and the argmax used for accuracy rely on that.
    cross = tf.nn.softmax_cross_entropy_with_logits(labels=label_batch,
                                                    logits=logits)
    loss = tf.reduce_mean(cross)

    # L2 regularization: add whatever penalty terms the model registered.
    list_reg = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if list_reg:
        total_loss = loss + tf.add_n(list_reg)
    else:
        total_loss = loss

    # Classification accuracy on the current batch.
    preds = tf.argmax(pred, 1)
    labels = tf.argmax(label_batch, 1)
    correct_pred = tf.equal(preds, labels)
    acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # The step counter is not a model weight; keep it out of the trainable set.
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(args.learning_rate,
                                    global_step=global_step,
                                    decay_steps=args.lr_decay_step,
                                    decay_rate=args.lr_decay,
                                    staircase=True)

    # Run the ops in UPDATE_OPS (e.g. batch-norm statistics) with each step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # Minimise total_loss, not the bare cross-entropy, so the
        # regularization terms gathered above actually affect training.
        train_op = tf.train.AdamOptimizer(
            learning_rate=lr, beta1=args.beta1).minimize(
                total_loss, global_step=global_step)

    max_steps = int(args.num_samples / int(args.batch_size) *
                    int(args.epoch))
    num_eval_batches = 200
    checkpoint_path = os.path.join(args.checkpoint_dir, args.model_name)

    # Summaries.
    tf.summary.scalar('total_loss', total_loss)
    tf.summary.scalar('accuracy', acc)
    tf.summary.scalar('learning_rate', lr)
    summary_op = tf.summary.merge_all()

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(args.logs_dir, sess.graph)

        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver()
        _, last_step = load(sess, saver, args.checkpoint_dir)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for step in range(last_step + 1, max_steps + 1):
                start_time = time.time()

                _, _lr = sess.run([train_op, lr])

                if step % args.num_log == 0:
                    summ, _loss, _acc = sess.run([summary_op, total_loss, acc])
                    writer.add_summary(summ, step)
                    print(
                        'number to eval:{0}, time:{1:.3f}, lr:{2:.8f}, acc:{3:.6f}, loss:{4:.6f}'
                        .format(step * args.batch_size,
                                time.time() - start_time, _lr, _acc, _loss))
                    # NOTE(review): checkpoints share the logging cadence
                    # (args.num_log). Confirm that this is intended.
                    saver.save(sess, checkpoint_path, global_step=step)

                if step % 100 == 0:
                    # NOTE: these batches come from the same training input
                    # queue, so the figures below are training-set averages.
                    totalloss = 0.0
                    totalacc = 0.0
                    for _ in range(num_eval_batches):
                        _loss, _acc = sess.run([total_loss, acc])
                        totalloss += _loss
                        totalacc += _acc

                    # Report the actual training step under the global_step
                    # label (previously the constant 200 * batch_size).
                    print('global_step:%g, time:%g, t acc:%g, t loss:%g' %
                          (step, time.time() - start_time,
                           totalacc / num_eval_batches,
                           totalloss / num_eval_batches))

            tf.train.write_graph(sess.graph_def, args.checkpoint_dir,
                                 args.model_name + '.pb')
            saver.save(sess, checkpoint_path, global_step=max_steps)
        finally:
            # Always stop the input threads and flush the event file,
            # even if training raised.
            coord.request_stop()
            coord.join(threads)
            writer.close()
Exemple #25
0
    def __init__(self, num_classes):
        """Build an SSD detector on top of a MobileNet backbone.

        Args:
            num_classes: number of classes scored for every prior box.
        """
        super(SSD, self).__init__()
        self.num_classes = num_classes

        # Setup the backbone network (base_net).
        self.base_net = MobileNet(num_classes)

        # The feature maps are extracted from layer[11] and layer[13] of base_net.
        self.base_output_layer_indices = (11, 13)

        def extra_layer(in_channels, mid_channels, out_channels):
            # 1x1 bottleneck conv, then a stride-2 3x3 conv; ReLU after each.
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels,
                          out_channels=mid_channels,
                          kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=mid_channels,
                          out_channels=out_channels,
                          kernel_size=3,
                          stride=2,
                          padding=1),
                nn.ReLU())

        # Additional feature extractors applied after the backbone.
        self.additional_feat_extractor = nn.ModuleList([
            extra_layer(1024, 256, 512),  # Conv8_2
            extra_layer(512, 128, 256),   # Conv9_2
            extra_layer(256, 128, 256),   # Conv10_2
            extra_layer(256, 128, 256),   # Conv11_2
        ])

        num_prior_bbox = 6  # num of prior bounding boxes
        # Channel counts of the six feature maps the heads are attached to:
        # two backbone outputs followed by the four extra layers above.
        source_channels = (512, 1024, 512, 256, 256, 256)

        def prediction_head(out_channels):
            # One 3x3 conv (padding 1) per source feature map.
            return nn.ModuleList([
                nn.Conv2d(in_channels=channels,
                          out_channels=out_channels,
                          kernel_size=3,
                          padding=1) for channels in source_channels
            ])

        # Bounding box offset regressor (4 offsets per prior box).
        self.loc_regressor = prediction_head(num_prior_bbox * 4)
        # Bounding box classification confidence for each label.
        self.classifier = prediction_head(num_prior_bbox * num_classes)

        # Todo: load the pre-trained model for self.base_net, it will increase the accuracy by fine-tuning
        def init_pretrained_weights(net_dict, pretrained_dict):
            """Rename/filter ``pretrained_dict`` against ``net_dict`` (mutated in place).

            Every key containing 'base_net' is copied under 'conv_layers' plus
            the key with its first len('base_net') characters dropped. Original
            keys absent from ``net_dict`` are deleted. Finally, every
            ``net_dict`` entry still missing is filled in from ``net_dict``.
            """
            ext_keys = []
            new_keys = []
            del_keys = []
            for key in pretrained_dict.keys():
                # Change key names.
                if key.find('base_net') > -1:
                    ext_keys.append(key)
                    new_keys.append('conv_layers' + key[len('base_net'):])
                # Discard parameters not in the target network.
                if key not in net_dict:
                    del_keys.append(key)
            # Copy values from ext_keys to new_keys.
            for old_key, new_key in zip(ext_keys, new_keys):
                pretrained_dict[new_key] = pretrained_dict[old_key]
            # Delete unmatched keys.
            for key in del_keys:
                pretrained_dict.pop(key)

            # Add names missing from the checkpoint (e.g. FC, unused in our
            # model) with their current values.
            for key in net_dict:
                if key not in pretrained_dict:
                    pretrained_dict[key] = net_dict[key]

            return pretrained_dict

        model_dict = self.state_dict()
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts.
        pretrained_dict = torch.load('./pretrained/mobienetv2.pth',
                                     map_location='cpu')
        pretrained_weights = init_pretrained_weights(model_dict,
                                                     pretrained_dict)
        # NOTE(review): this only rebinds entries of the ``model_dict`` Python
        # dict. There is no ``load_state_dict`` call, so nothing is copied into
        # the module's parameters. In addition, the renamed 'conv_layers...'
        # keys cannot match this module's own keys, which all start with the
        # submodule names (base_net., additional_feat_extractor.,
        # loc_regressor., classifier.). Confirm the intended target
        # (probably self.base_net) before wiring the weights in.
        model_dict.update(pretrained_weights)

        def init_with_xavier(m):
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)

        self.loc_regressor.apply(init_with_xavier)
        self.classifier.apply(init_with_xavier)
        self.additional_feat_extractor.apply(init_with_xavier)
Exemple #26
0
def _model_fn(num_bits, features, labels, mode, params):
    """Estimator ``model_fn`` for an (optionally quantised) MobileNet classifier.

    Args:
        num_bits: bit width forwarded to ``MobileNet``.
        features: input batch passed to ``model.forward_pass``.
        labels: integer class ids (used with sparse softmax cross-entropy);
            not needed in PREDICT mode.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: unused; part of the Estimator ``model_fn`` signature.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for ``mode``.
    """
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # L2 weight regularizer, handed to MobileNet as its conv2d regularizer.
    regularizer = tf.contrib.layers.l2_regularizer(scale=config.weight_decay)
    # Create the model.
    num_classes = 10
    model = MobileNet(num_classes,
                      is_training,
                      num_bits,
                      width_multiplier=config.width_multiplier,
                      quant_mode=config.quant_method,
                      conv2d_regularizer=regularizer)

    # Forward pass.
    logits = model.forward_pass(features)
    predict_class = tf.argmax(input=logits, axis=1)
    predictions = {
        'classes': predict_class,
        'probabilities': tf.nn.softmax(logits)
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        # No labels at inference time: return before building loss/metrics
        # (previously this mode failed on labels=None).
        if config.quant_method == 'tensorflow':
            tf.contrib.quantize.create_eval_graph(
                input_graph=tf.get_default_graph())
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Accuracy metric.
    accuracy = tf.metrics.accuracy(labels, predictions['classes'])
    metrics = {'accuracy': accuracy}

    # Loss function.
    loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)

    # Regularization loss. REGULARIZATION_LOSSES already holds penalty
    # tensors (the regularizer applied to the weights), so they only need
    # summing. Passing them through apply_regularization again, as before,
    # applied the L2 penalty to those scalar penalties instead of to the
    # weights. get_regularization_loss() sums the collection (0.0 if empty).
    loss += tf.losses.get_regularization_loss()

    if mode == tf.estimator.ModeKeys.TRAIN:

        # Add fake_quant ops to the training graph.
        if config.quant_method == 'tensorflow':
            print("TF quantize create training graph")
            g = tf.get_default_graph()
            tf.contrib.quantize.create_training_graph(input_graph=g,
                                                      quant_delay=0)

        # Learning rate decay.
        global_step = tf.train.get_global_step()
        steps_per_epoch = num_training_per_epoch / config.train_batch_size
        decay_steps = steps_per_epoch * config.decay_per_epoch
        decay_rate = config.decay_rate

        learning_rate = tf.train.exponential_decay(config.learning_rate,
                                                   global_step, decay_steps,
                                                   decay_rate)
        # Never decay below 1% of the initial learning rate.
        learning_rate = tf.maximum(learning_rate, config.learning_rate * 0.01)

        optimizer = tf.train.AdamOptimizer(learning_rate)

        # Run the ops in UPDATE_OPS (e.g. batch-norm statistics) with each step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss=loss, global_step=global_step)

        # Logging.
        tf.summary.scalar("accuracy", accuracy[1])
        tf.summary.scalar("learning_rate", learning_rate)

        tensors_to_log = {
            'learning_rate': learning_rate,
            'loss': loss,
            'accuracy': accuracy[1]
        }
        train_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log,
                                                every_n_iter=1000)

        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions,
                                          loss=loss,
                                          train_op=train_op,
                                          training_hooks=[train_hook],
                                          eval_metric_ops=metrics)

    elif mode == tf.estimator.ModeKeys.EVAL:
        if config.quant_method == 'tensorflow':
            g = tf.get_default_graph()
            tf.contrib.quantize.create_eval_graph(input_graph=g)

        tf.summary.scalar("accuracy", accuracy[1])
        eval_tensors_to_log = {'eval_loss': loss, 'eval_accuracy': accuracy[1]}
        evaluation_hook = tf.train.LoggingTensorHook(
            tensors=eval_tensors_to_log, every_n_iter=1000)

        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions,
                                          loss=loss,
                                          evaluation_hooks=[evaluation_hook],
                                          eval_metric_ops=metrics)
Exemple #27
0
#                May the mythical beast protect this code
#                There can't possibly be any BUGs!
import os

# Set before Keras (and therefore TensorFlow) is imported; level '2' is meant
# to hide TensorFlow's C++ INFO/WARNING log output.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

from mobilenet import MobileNet
import numpy as np
# Seed NumPy's RNG before building the model, for repeatable runs.
np.random.seed(10)
from keras.datasets import cifar10
from keras.utils import np_utils

WEIGHTS_PATH = "mobileV1-lite.h5"

(x_img_train, y_label_train), (x_img_test, y_label_test) = cifar10.load_data()

# Scale pixel values from 0..255 to 0..1.
x_img_train = x_img_train.astype('float') / 255.0
x_img_test = x_img_test.astype('float') / 255.0

# One-hot encode the integer labels for categorical cross-entropy.
y_label_train = np_utils.to_categorical(y_label_train)
y_label_test = np_utils.to_categorical(y_label_test)

model = MobileNet()
# Resume from previously saved weights if possible; otherwise train from
# scratch. Catch only load failures: a missing/unreadable file (OSError),
# mismatched weights (ValueError), or missing h5py support (ImportError).
# Interrupts and other errors are no longer swallowed.
try:
    model.load_weights(WEIGHTS_PATH)
except (ImportError, OSError, ValueError):
    print("模型加载失败!从头开始训练")  # "Failed to load model! Training from scratch."
else:
    print("模型加载成功!继续训练")  # "Model loaded! Continuing training."

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# NOTE(review): the test split prepared above is never evaluated; validation
# uses 20% of the training data instead.
train_history = model.fit(x_img_train, y_label_train, validation_split=0.2,
                          epochs=10, batch_size=128, verbose=2)
model.save_weights(WEIGHTS_PATH)
print("保存模型成功!")  # "Model saved successfully!"
Exemple #28
0
    def __init__(self, num_classes):
        """Assemble the SSD network: MobileNet backbone, extra feature layers, prediction heads.

        Args:
            num_classes: number of classes predicted for each prior box.
        """
        super(SSD, self).__init__()
        self.num_classes = num_classes
        # Backbone network (base_net).
        self.base_net = MobileNet(num_classes)
        # Feature maps are taken from layer[11] and layer[13] of base_net.
        self.base_output_layer_indices = (11, 13)

        def bottleneck(in_ch, mid_ch, out_ch, stride, padding):
            # 1x1 channel reduction followed by a 3x3 conv, ReLU after each.
            return nn.Sequential(
                nn.Conv2d(in_channels=in_ch, out_channels=mid_ch, kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=mid_ch,
                          out_channels=out_ch,
                          kernel_size=3,
                          stride=stride,
                          padding=padding),
                nn.ReLU(),
            )

        # Additional feature extractors applied after the backbone.
        self.additional_feat_extractor = nn.ModuleList([
            bottleneck(1024, 256, 512, stride=2, padding=1),  # Conv8_2
            bottleneck(512, 128, 256, stride=2, padding=1),   # Conv9_2
            bottleneck(256, 128, 256, stride=1, padding=1),   # Conv10_2
            bottleneck(256, 128, 256, stride=1, padding=0),   # Conv11_2
        ])

        prior_count = 6  # prior bounding boxes per feature-map location
        # Input channels of the maps the heads read, in order:
        # Conv5_3, FC7, Conv8_2, Conv9_2, Conv10_2, Conv11_2.
        feature_channels = [512, 1024, 512, 256, 256, 256]

        # Bounding box offset regressor: 4 offsets per prior box.
        self.loc_regressor = nn.ModuleList([
            nn.Conv2d(in_channels=ch,
                      out_channels=prior_count * 4,
                      kernel_size=3,
                      padding=1) for ch in feature_channels
        ])

        # Per-class confidence for every prior box.
        self.classifier = nn.ModuleList([
            nn.Conv2d(in_channels=ch,
                      out_channels=prior_count * num_classes,
                      kernel_size=3,
                      padding=1) for ch in feature_channels
        ])

        # Fine-tune from the pretrained backbone: keep only the checkpoint
        # entries whose key contains 'base_net' and load them into base_net.
        checkpoint = torch.load('pretrained/mobienetv2.pth',
                                map_location='cpu')
        backbone_state = {}
        for name, tensor in checkpoint.items():
            if 'base_net' in name:
                backbone_state[name] = tensor
        self.base_net.load_state_dict(backbone_state)

        def xavier_conv_init(module):
            # Xavier-uniform initialisation for conv weights only.
            if isinstance(module, nn.Conv2d):
                nn.init.xavier_uniform_(module.weight)

        for part in (self.loc_regressor, self.classifier,
                     self.additional_feat_extractor):
            part.apply(xavier_conv_init)
         model = Resnet_interpretable_gradcam(num_classes=num_classe)
     elif args.model_type == 'ex_gradcam2':
         model = VGG_interpretable_gradcam2(num_classes=num_classe)
     else:
         model = Resnet(num_classes=num_classe)
 elif args.model == 'mobilenet':
     if args.model_type == 'ex_atten':
         model = VGG_interpretable_atten(num_classes=num_classe)
     elif args.model_type == 'ex':
         model = VGG_interpretable(num_classes=num_classe)
     elif args.model_type == 'ex_gradcam':
         model = Mobile_interpretable_gradcam(num_classes=num_classe)
     elif args.model_type == 'ex_gradcam2':
         model = VGG_interpretable_gradcam2(num_classes=num_classe)
     else:
         model = MobileNet(num_classes=num_classe)
 elif args.model == 'alexnet':
     if args.model_type == 'ex_atten':
         model = VGG_interpretable_atten(num_classes=num_classe)
     elif args.model_type == 'ex':
         model = Alexnet_interpretable(num_classes=num_classe)
     elif args.model_type == 'ex_gradcam':
         model = Alexnet_interpretable_gradcam(num_classes=num_classe)
     else:
         model = Alexnet(num_classes=num_classe)
 use_gpu = torch.cuda.is_available()  # 判断是否有GPU加速
 if use_gpu:
     model = model.cuda()
 if model_half:
     model = model.half()
 if args.model_init:
def load_model(args):
    """Build a headless, ImageNet-pretrained Keras model selected by ``args``.

    Args:
        args: namespace with these attributes:

            * ``model``: architecture name, one of ``inception``,
              ``xception``, ``inceptionresnet``, ``mobilenet``,
              ``mobilenet2``, ``nasnet``, ``resnet``, ``vgg16`` or ``vgg19``.
            * ``output_layer``: ``'0'`` to use the network's own (pooled)
              output, or the name of an inner layer whose output becomes the
              model output.
            * ``pooling``: forwarded to the Keras application constructor.

    Returns:
        ``(model, preprocess_mode)``, where ``preprocess_mode`` is ``'tf'`` or
        ``'caffe'`` depending on the backbone. For an unknown ``args.model``,
        prints "Model not found" and returns ``0``; callers may check for
        this sentinel.
    """
    # Pick the constructor inside each branch so only the selected
    # application class has to be resolvable.
    if args.model == 'inception':
        constructor, preprocess_mode = InceptionV3, 'tf'
    elif args.model == 'xception':
        constructor, preprocess_mode = Xception, 'tf'
    elif args.model == 'inceptionresnet':
        constructor, preprocess_mode = InceptionResNetV2, 'tf'
    elif args.model == 'mobilenet':
        constructor, preprocess_mode = MobileNet, 'tf'
    elif args.model == 'mobilenet2':
        constructor, preprocess_mode = MobileNetV2, 'tf'
    elif args.model == 'nasnet':
        constructor, preprocess_mode = NASNetLarge, 'tf'
    elif args.model == 'resnet':
        constructor, preprocess_mode = ResNet50, 'caffe'
    elif args.model == 'vgg16':
        constructor, preprocess_mode = VGG16, 'caffe'
    elif args.model == 'vgg19':
        constructor, preprocess_mode = VGG19, 'caffe'
    else:
        print("Model not found")
        return 0

    base_model = constructor(include_top=False, weights='imagenet',
                             pooling=args.pooling)
    if args.output_layer == '0':
        return base_model, preprocess_mode

    # ``inputs``/``outputs`` are the Keras 2 keyword names. The legacy
    # ``input``/``output`` spellings are deprecated and are not accepted by
    # later Keras / tf.keras releases.
    model = Model(inputs=base_model.input,
                  outputs=base_model.get_layer(args.output_layer).output)
    return model, preprocess_mode