def load_model(args, num_class, trained_param=None):
    """Build a ResNet model and optionally seed its initializers from a pickle.

    Args:
        args: Namespace-like object; only ``args.arch`` is read. It must
            contain ``'ResNet'`` and have the form ``'<name>-<depth>'``
            (e.g. ``'ResNet-56'``). The integer after the first ``'-'`` is
            passed to ``ResNet.Model`` as ``num_layers``.
        num_class: Number of output classes passed to ``ResNet.Model``.
        trained_param: Optional path to a pickled dict mapping variable names
            such as ``'<layer name>/kernel:0'`` to arrays. When given, the
            initializers of every layer whose key contains ``'conv'``/``'fc'``
            or ``'bn'`` are replaced with ``tf.constant_initializer`` built
            from those arrays.

    Returns:
        The constructed ``ResNet.Model`` instance.

    Raises:
        ValueError: if ``args.arch`` does not name a supported architecture.
        KeyError: if ``trained_param`` lacks a variable that a layer needs.
    """
    if 'ResNet' in args.arch:
        depth = int(args.arch.split('-')[1])
        model = ResNet.Model(num_layers=depth,
                             num_class=num_class,
                             name='ResNet',
                             trainable=True)
    else:
        # Without this branch an unsupported arch left ``model`` unbound and
        # the function crashed later with an opaque UnboundLocalError.
        raise ValueError(f'Unsupported architecture for load_model: {args.arch!r}')

    if trained_param is not None:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # parameter files that come from a trusted source.
        with open(trained_param, 'rb') as f:
            trained = pickle.load(f)

        n = 0  # number of parameter arrays assigned to initializers
        for k, layer in model.Layers.items():
            if 'conv' in k or 'fc' in k:
                kernel = trained[layer.name + '/kernel:0']
                layer.kernel_initializer = tf.constant_initializer(kernel)
                n += 1
                if layer.use_biases:
                    layer.biases_initializer = tf.constant_initializer(
                        trained[layer.name + '/biases:0'])
                    n += 1
                # Match the layer's output width to the stored kernel's last dim.
                layer.num_outputs = kernel.shape[-1]

            elif 'bn' in k:
                param_initializers = {
                    'moving_mean': tf.constant_initializer(
                        trained[layer.name + '/moving_mean:0']),
                    'moving_variance': tf.constant_initializer(
                        trained[layer.name + '/moving_variance:0']),
                }
                n += 2
                # gamma/beta only exist when the layer scales/centers.
                if layer.scale:
                    param_initializers['gamma'] = tf.constant_initializer(
                        trained[layer.name + '/gamma:0'])
                    n += 1
                if layer.center:
                    param_initializers['beta'] = tf.constant_initializer(
                        trained[layer.name + '/beta:0'])
                    n += 1
                layer.param_initializers = param_initializers
        print(n, 'params loaded')
    return model
Example #2
0
    args.input_size = list(train_images.shape[1:])

    if 'WResNet' in args.arch:
        arch = [int(a) for a in args.arch.split('-')[1:]]
        model = WResNet.Model(architecture=arch,
                              num_class=np.max(train_labels) + 1,
                              name='Student',
                              trainable=True)
    elif 'VGG' in args.arch:
        model = VGG.Model(num_class=np.max(train_labels) + 1,
                          name='Student',
                          trainable=True)
    elif 'ResNet' in args.arch:
        arch = int(args.arch.split('-')[1])
        model = ResNet.Model(num_layers=arch,
                             num_class=np.max(train_labels) + 1,
                             name='Student',
                             trainable=True)
    elif 'Mobilev2' in args.arch:
        model = Mobilev2.Model(num_class=np.max(train_labels) + 1,
                               width_mul=1.0 if args.slimmable else 1.0,
                               name='Student',
                               trainable=True)

    model(np.zeros([1] + args.input_size, dtype=np.float32), training=False)

    cardinality = tf.data.experimental.cardinality(datasets['train']).numpy()
    if args.decay_points is None:
        LR = tf.keras.optimizers.schedules.ExponentialDecay(args.learning_rate,
                                                            cardinality,
                                                            args.decay_rate,
                                                            staircase=True)