def _seed_conv_or_fc(layer, trained):
    """Point a conv/fc layer's initializers at stored weights; return how many were set."""
    kernel = trained[layer.name + '/kernel:0']
    layer.kernel_initializer = tf.constant_initializer(kernel)
    loaded = 1
    if layer.use_biases:
        layer.biases_initializer = tf.constant_initializer(
            trained[layer.name + '/biases:0'])
        loaded += 1
    # Take the output width from the stored kernel's last axis so the layer
    # is sized to match the weights it will be initialised with.
    layer.num_outputs = kernel.shape[-1]
    return loaded


def _seed_batch_norm(layer, trained):
    """Point a batch-norm layer's initializers at stored statistics; return how many were set."""
    param_initializers = {
        'moving_mean': tf.constant_initializer(
            trained[layer.name + '/moving_mean:0']),
        'moving_variance': tf.constant_initializer(
            trained[layer.name + '/moving_variance:0']),
    }
    loaded = 2
    # gamma / beta are only looked up when the layer actually uses them.
    if layer.scale:
        param_initializers['gamma'] = tf.constant_initializer(
            trained[layer.name + '/gamma:0'])
        loaded += 1
    if layer.center:
        param_initializers['beta'] = tf.constant_initializer(
            trained[layer.name + '/beta:0'])
        loaded += 1
    layer.param_initializers = param_initializers
    return loaded


def load_model(args, num_class, trained_param=None):
    """Build a ResNet model and optionally seed it with stored parameters.

    Args:
        args: Options object; only ``args.arch`` is read. It must contain
            ``'ResNet'`` and carry the depth after the first dash
            (e.g. ``'ResNet-56'``).
        num_class: Number of output classes forwarded to ``ResNet.Model``.
        trained_param: Optional path to a pickle file holding a mapping from
            variable names (``'<layer name>/kernel:0'``, ``'/biases:0'``,
            ``'/moving_mean:0'``, ``'/moving_variance:0'``, ``'/gamma:0'``,
            ``'/beta:0'``) to their values. When given, the matching layers
            get constant initializers built from those values.

    Returns:
        The constructed ``ResNet.Model`` instance.

    Raises:
        ValueError: If ``args.arch`` does not name a ResNet architecture.
            (Previously this surfaced as an ``UnboundLocalError`` on return.)
        KeyError: If the pickle lacks an entry a layer needs.
    """
    # NOTE(review): this is a substring test, so names such as 'WResNet-16-4'
    # also match -- confirm callers only pass plain ResNet names here.
    if 'ResNet' not in args.arch:
        raise ValueError(
            'Unsupported architecture {!r}: load_model only builds ResNet '
            'models'.format(args.arch))

    arch = int(args.arch.split('-')[1])
    model = ResNet.Model(num_layers=arch, num_class=num_class,
                         name='ResNet', trainable=True)

    if trained_param is not None:
        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only pass parameter files that come from a trusted source.
        with open(trained_param, 'rb') as f:
            trained = pickle.load(f)

        n = 0
        for k in model.Layers.keys():
            layer = model.Layers[k]
            # Layers are dispatched on their key: conv/fc carry kernels,
            # bn carries normalisation statistics; anything else is skipped.
            if 'conv' in k or 'fc' in k:
                n += _seed_conv_or_fc(layer, trained)
            elif 'bn' in k:
                n += _seed_batch_norm(layer, trained)
        print(n, 'params loaded')

    return model
# Per-sample input shape: the training images' shape without its leading axis.
args.input_size = list(train_images.shape[1:])

# Class count is derived from the largest label value
# (assumes labels are integer ids starting at 0 -- TODO confirm).
num_student_classes = np.max(train_labels) + 1

# Build the student network selected by args.arch. 'WResNet' has to be tested
# before 'ResNet' because the latter is a substring of the former.
if 'WResNet' in args.arch:
    arch = [int(a) for a in args.arch.split('-')[1:]]
    model = WResNet.Model(architecture=arch,
                          num_class=num_student_classes,
                          name='Student',
                          trainable=True)
elif 'VGG' in args.arch:
    model = VGG.Model(num_class=num_student_classes,
                      name='Student',
                      trainable=True)
elif 'ResNet' in args.arch:
    arch = int(args.arch.split('-')[1])
    model = ResNet.Model(num_layers=arch,
                         num_class=num_student_classes,
                         name='Student',
                         trainable=True)
elif 'Mobilev2' in args.arch:
    # NOTE(review): both branches of this conditional are 1.0, so
    # args.slimmable currently has no effect on the width multiplier --
    # confirm the intended value for the slimmable case.
    model = Mobilev2.Model(num_class=num_student_classes,
                           width_mul=1.0 if args.slimmable else 1.0,
                           name='Student',
                           trainable=True)
else:
    # Without this branch an unknown architecture only surfaced later as a
    # NameError on `model`.
    raise ValueError('Unsupported architecture {!r}'.format(args.arch))

# Run a single all-zeros sample through the model in inference mode
# (presumably to force variable creation before training -- verify).
model(np.zeros([1] + args.input_size, dtype=np.float32), training=False)

# Number of elements the training dataset yields per pass.
cardinality = tf.data.experimental.cardinality(datasets['train']).numpy()

if args.decay_points is None:
    # No explicit decay points: multiply the learning rate by args.decay_rate
    # once every `cardinality` optimizer steps (staircase decay).
    LR = tf.keras.optimizers.schedules.ExponentialDecay(args.learning_rate,
                                                        cardinality,
                                                        args.decay_rate,
                                                        staircase=True)