from copy import deepcopy

import h5py
import keras
import numpy as np
from keras import backend as K
from keras.losses import binary_crossentropy
from keras.models import Model


def build(self, input_shape):
    # Replace the default ResNet-50 backbone with an extended copy: the
    # helper methods (defined on the enclosing class) widen the config and
    # the weight list, and the rebuilt model becomes the new backbone.
    base_model = self._default_resnet50()
    ext_config = self._extend_config(base_model.get_config())
    ext_weights = self._extend_weights(base_model.get_weights())
    ext_model = Model.from_config(ext_config)
    ext_model.set_weights(ext_weights)
    self.backbone = ext_model
    super().build(input_shape)
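# Hedged usage sketch: `build` above is a method of a Keras layer/model
# subclass; the class itself and its `_default_resnet50`, `_extend_config`,
# and `_extend_weights` helpers are not shown in this file. Assuming a
# hypothetical subclass named ExtendedBackbone, it would be triggered the
# usual Keras way:
#
#     layer = ExtendedBackbone()
#     layer.build((None, 224, 224, 3))  # or implicitly on the first call
#     features = layer.backbone(images)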
def wider(conv1, conv2, old_model, add_filter):
    """Net2Net-style widening: grow `conv1` by a fraction `add_filter` of its
    filters, compensating in the following conv (`conv2`) so the network's
    function is (approximately) preserved."""
    w_conv1, b_conv1 = old_model.get_layer(conv1).get_weights()
    w_conv2, b_conv2 = old_model.get_layer(conv2).get_weights()
    add_num = round(add_filter * w_conv1.shape[3])
    if add_num == 0:
        return old_model

    # Randomly pick existing filters to duplicate; `factors` counts how many
    # copies of each picked filter exist afterwards (original + duplicates),
    # which is the divisor that keeps the next layer's output unchanged.
    index = np.random.randint(w_conv1.shape[3], size=add_num)
    factors = np.bincount(index)[index] + 1.
    tmp_w1 = w_conv1[:, :, :, index]
    tmp_b1 = b_conv1[index]
    tmp_w2 = w_conv2[:, :, index, :] / factors.reshape((1, 1, -1, 1))

    # A little noise breaks the symmetry between a filter and its copies.
    noise = np.random.normal(0, 5e-2 * tmp_w2.std(), size=tmp_w2.shape)
    new_w_conv1 = np.concatenate((w_conv1, tmp_w1), axis=3)
    new_w_conv2 = np.concatenate((w_conv2, tmp_w2 + noise), axis=2)
    new_w_conv2[:, :, index, :] = tmp_w2  # rescale the original positions too
    new_b_conv1 = np.concatenate((b_conv1, tmp_b1), axis=0)

    model_config = old_model.get_config()
    bn_name = ''
    shortcut_name = ''
    next_conv = ''

    # Widen conv1 in the config.
    for layer in model_config['layers']:
        if layer['config']['name'] == conv1:
            layer['config']['filters'] += add_num
            break

    # Locate the BatchNormalization that follows conv1. If conv1 feeds an
    # Add (i.e. it closes a residual block), the BN sits one layer further,
    # the shortcut conv is the layer right before conv1, and the first conv
    # of the next block sits a fixed 7 layers ahead -- this hard-coded
    # layout matches the model these functions were written for.
    for i, layer in enumerate(model_config['layers']):
        if layer['config']['name'] == conv1:
            if model_config['layers'][i + 1]['class_name'] == 'BatchNormalization':
                bn_name = model_config['layers'][i + 1]['name']
            elif model_config['layers'][i + 1]['class_name'] == 'Add':
                bn_name = model_config['layers'][i + 2]['name']
                shortcut_name = model_config['layers'][i - 1]['name']
                next_conv = model_config['layers'][i + 7]['name']
            break

    # Widen the shortcut conv in the config as well.
    for layer in model_config['layers']:
        if layer['config']['name'] == shortcut_name:
            layer['config']['filters'] += add_num
            break

    # Duplicate the BN parameters (gamma, beta, moving mean, moving variance)
    # for the copied channels.
    a, b, c, d = old_model.get_layer(bn_name).get_weights()
    new_a = np.concatenate((a, a[index]), axis=0)
    new_b = np.concatenate((b, b[index]), axis=0)
    new_c = np.concatenate((c, c[index]), axis=0)
    new_d = np.concatenate((d, d[index]), axis=0)

    if shortcut_name != '':
        # The shortcut conv must grow by the same channels, and the first
        # conv of the next block must absorb them, exactly as conv2 did.
        w_1, b_1 = old_model.get_layer(shortcut_name).get_weights()
        w_2, b_2 = old_model.get_layer(next_conv).get_weights()
        t_w1 = w_1[:, :, :, index]
        t_b1 = b_1[index]
        t_w2 = w_2[:, :, index, :] / factors.reshape((1, 1, -1, 1))
        noise = np.random.normal(0, 5e-2 * t_w2.std(), size=t_w2.shape)
        new_w_1 = np.concatenate((w_1, t_w1), axis=3)
        new_w_2 = np.concatenate((w_2, t_w2 + noise), axis=2)
        new_w_2[:, :, index, :] = t_w2
        new_b_1 = np.concatenate((b_1, t_b1), axis=0)

    # Rebuild the model from the widened config and copy the weights over,
    # substituting the widened tensors where needed.
    new_model = Model.from_config(model_config)
    for layer in new_model.layers:
        if layer.name == conv1:
            layer.set_weights([new_w_conv1, new_b_conv1])
        elif layer.name == conv2:
            layer.set_weights([new_w_conv2, b_conv2])
        elif layer.name == bn_name:
            layer.set_weights([new_a, new_b, new_c, new_d])
        elif layer.name == shortcut_name:
            layer.set_weights([new_w_1, new_b_1])
        elif layer.name == next_conv:
            layer.set_weights([new_w_2, b_2])
        else:
            layer.set_weights(old_model.get_layer(layer.name).get_weights())
    return new_model
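# Hedged usage sketch for `wider`. The helper below builds a toy model (not
# from the original project) whose config satisfies the layout `wider`
# assumes: each conv is immediately followed by its BatchNormalization.
def _toy_model():
    from keras.layers import (Activation, BatchNormalization, Conv2D, Dense,
                              GlobalAveragePooling2D, Input)
    inputs = Input((32, 32, 3))
    x = Conv2D(8, 3, padding='same', name='conv_a')(inputs)
    x = BatchNormalization(name='bn_a')(x)
    x = Activation('relu')(x)
    x = Conv2D(16, 3, padding='same', name='conv_b')(x)
    x = BatchNormalization(name='bn_b')(x)
    x = GlobalAveragePooling2D()(x)
    outputs = Dense(10, name='dense_1')(x)
    return Model(inputs, outputs)


# Widen conv_a by 50%, letting conv_b absorb the duplicated channels:
#
#     model = _toy_model()
#     model = wider('conv_a', 'conv_b', model, 0.5)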
def wider_last(conv1, old_model, add_filter):
    """Widen the last convolution before the classifier. The duplicated
    filters' rows in the dense kernel are zero-filled, so the network's
    function is preserved exactly."""
    w_conv1, b_conv1 = old_model.get_layer(conv1).get_weights()
    add_num = round(add_filter * w_conv1.shape[3])
    if add_num == 0:
        return old_model

    index = np.random.randint(w_conv1.shape[3], size=add_num)
    new_w_conv1 = np.concatenate((w_conv1, w_conv1[:, :, :, index]), axis=3)
    new_b_conv1 = np.concatenate((b_conv1, b_conv1[index]), axis=0)

    model_config = old_model.get_config()
    bn_name = ''
    shortcut_name = ''

    for layer in model_config['layers']:
        if layer['config']['name'] == conv1:
            layer['config']['filters'] += add_num
            break

    # Same hard-coded layout assumptions as in `wider` above.
    for i, layer in enumerate(model_config['layers']):
        if layer['config']['name'] == conv1:
            if model_config['layers'][i + 1]['class_name'] == 'BatchNormalization':
                bn_name = model_config['layers'][i + 1]['name']
            elif model_config['layers'][i + 1]['class_name'] == 'Add':
                bn_name = model_config['layers'][i + 2]['name']
                shortcut_name = model_config['layers'][i - 1]['name']
            break

    for layer in model_config['layers']:
        if layer['config']['name'] == shortcut_name:
            layer['config']['filters'] += add_num
            break

    a, b, c, d = old_model.get_layer(bn_name).get_weights()
    new_a = np.concatenate((a, a[index]), axis=0)
    new_b = np.concatenate((b, b[index]), axis=0)
    new_c = np.concatenate((c, c[index]), axis=0)
    new_d = np.concatenate((d, d[index]), axis=0)

    if shortcut_name != '':
        w_1, b_1 = old_model.get_layer(shortcut_name).get_weights()
        new_w_1 = np.concatenate((w_1, w_1[:, :, :, index]), axis=3)
        new_b_1 = np.concatenate((b_1, b_1[index]), axis=0)

    # Zero-fill the dense rows for the new channels so the logits are
    # unaffected. This assumes global average pooling between the conv and
    # the dense layer, so dense inputs map one-to-one to channels.
    w, b_dense = old_model.get_layer('dense_1').get_weights()
    new_w_dense = np.concatenate((w, np.zeros((add_num, w.shape[1]))), axis=0)

    new_model = Model.from_config(model_config)
    for layer in new_model.layers:
        if layer.name == conv1:
            layer.set_weights([new_w_conv1, new_b_conv1])
        elif layer.name == bn_name:
            layer.set_weights([new_a, new_b, new_c, new_d])
        elif layer.name == shortcut_name:
            layer.set_weights([new_w_1, new_b_1])
        elif layer.name == 'dense_1':
            layer.set_weights([new_w_dense, b_dense])
        else:
            layer.set_weights(old_model.get_layer(layer.name).get_weights())
    return new_model
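# Hedged usage sketch for `wider_last`, reusing `_toy_model` from above. Note
# the hard-coded 'dense_1' layer name inside `wider_last`: the layer fed by
# the global pooling must carry that name.
#
#     model = _toy_model()
#     model = wider_last('conv_b', model, 0.5)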
def train_model_incrementally(old_model_file, X, Y, output_file,
                              batch_size=8, epochs=50, λ=100000000,
                              fisher_samples=200):
    """Take a model which already includes an estimate of its Fisher
    information matrix and train it further using EWC.

    old_model_file: path to an HDF5 file containing weights in Keras format
        as well as the Fisher information matrix
    X: inputs for training
    Y: outputs for training
    output_file: the HDF5 file for output (again containing weights and the
        Fisher information matrix, so calls can be chained)
    batch_size: integer, passed directly to the model
    epochs: integer, passed directly to the model
    λ: float, the EWC hyperparameter. Higher values keep the model closer to
        the old weights; lower values let it optimize more for the current
        task. I've had good experience with values near the default of 10^8.
    fisher_samples: integer, the number of samples to use for estimating the
        Fisher information matrix.
    """
    # Load the old model (dice_coef is the project's Dice metric; a sketch
    # is given below).
    old_model = keras.models.load_model(
        old_model_file, custom_objects={'dice_coef': dice_coef})

    # Load the old Fisher diagonal; sorted() restores the order in which
    # write_fisher() stored the per-weight arrays.
    fisher = []
    with h5py.File(old_model_file, 'r') as f:
        weight_keys = sorted(f['fisher_diagonal'].keys())
        for key in weight_keys:
            fisher.append(np.array(f['fisher_diagonal'][key]))

    # Build the EWC loss: the standard loss plus a quadratic penalty,
    # weighted by the Fisher diagonal, on drifting away from the old weights.
    def ewc_loss(model):
        def loss(y_true, y_pred):
            standard_loss = binary_crossentropy(y_true, y_pred)
            ewc_term = 0
            for i in range(len(fisher)):
                # Pair trainable weights by index on both models. (Indexing
                # get_weights() here instead would misalign for models with
                # non-trainable arrays such as BatchNormalization statistics.)
                Δweights = (model.trainable_weights[i]
                            - old_model.trainable_weights[i])
                ewc_term += K.sum((λ / 2) * fisher[i] * Δweights ** 2)
            return standard_loss + ewc_term
        return loss

    # Clone the old model and train it with the EWC loss.
    new_model_ewc = Model.from_config(old_model.get_config())
    new_model_ewc.set_weights(deepcopy(old_model.get_weights()))
    new_model_ewc.compile(optimizer='adam',
                          loss=ewc_loss(new_model_ewc),
                          metrics=[dice_coef])
    new_model_ewc.fit(X, Y, batch_size=batch_size, epochs=epochs)

    # Re-wrap the trained weights in a model with a plain binary
    # cross-entropy loss, so the saved file loads without the EWC closure.
    new_model_bce = Model.from_config(new_model_ewc.get_config())
    new_model_bce.set_weights(new_model_ewc.get_weights())
    new_model_bce.compile(optimizer='adam',
                          loss='binary_crossentropy',
                          metrics=[dice_coef])

    # Estimate the Fisher diagonal for the new task and store it in the same
    # file as the weights, so the output can seed the next incremental step.
    new_fisher = estimate_fisher_information(new_model_bce, X, Y,
                                             fisher_samples)
    new_model_bce.save(output_file)
    write_fisher(output_file, new_fisher)
    return new_model_bce
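# The helpers called above (dice_coef, estimate_fisher_information,
# write_fisher) are defined elsewhere in the project. The versions below are
# hedged sketches of what they plausibly look like, not the originals: the
# Dice metric is the standard smoothed formulation, the Fisher estimate is
# the usual squared-log-likelihood-gradient diagonal, and write_fisher stores
# one dataset per trainable weight under the 'fisher_diagonal' group with
# zero-padded keys so that sorted() restores their order.

import tensorflow as tf


def dice_coef(y_true, y_pred, smooth=1.0):
    # Sketch: standard smoothed Dice coefficient.
    intersection = K.sum(y_true * y_pred)
    return ((2.0 * intersection + smooth)
            / (K.sum(y_true) + K.sum(y_pred) + smooth))


def estimate_fisher_information(model, X, Y, num_samples):
    # Sketch, assuming tf.keras under TensorFlow 2 with eager execution
    # (adapt for graph-mode Keras): approximate the Fisher diagonal as the
    # mean squared gradient of the per-sample Bernoulli log-likelihood with
    # respect to each trainable weight.
    fisher = [np.zeros(v.shape, dtype=np.float32)
              for v in model.trainable_weights]
    for i in np.random.randint(len(X), size=num_samples):
        x = X[i:i + 1].astype('float32')
        y = Y[i:i + 1].astype('float32')
        with tf.GradientTape() as tape:
            pred = model(x, training=False)
            log_likelihood = tf.reduce_sum(
                y * tf.math.log(pred + 1e-7)
                + (1. - y) * tf.math.log(1. - pred + 1e-7))
        grads = tape.gradient(log_likelihood, model.trainable_weights)
        for j, g in enumerate(grads):
            if g is not None:
                fisher[j] += np.square(g.numpy()) / num_samples
    return fisher


def write_fisher(path, fisher):
    # Sketch: append the Fisher diagonal to the already-saved model file.
    with h5py.File(path, 'a') as f:
        if 'fisher_diagonal' in f:
            del f['fisher_diagonal']
        group = f.create_group('fisher_diagonal')
        for i, arr in enumerate(fisher):
            group.create_dataset('weight_{:03d}'.format(i), data=arr)


# Hedged usage sketch: chain EWC training across tasks. Assumes 'task0.h5'
# was produced by model.save() followed by write_fisher(), and that the
# (X1, Y1) and (X2, Y2) arrays belong to successive tasks.
#
#     model = train_model_incrementally('task0.h5', X1, Y1, 'task1.h5')
#     model = train_model_incrementally('task1.h5', X2, Y2, 'task2.h5')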