예제 #1
0
    def build(self, input_shape):
        """Assemble ``self.backbone`` and then run the parent class's ``build``.

        The model returned by ``self._default_resnet50()`` supplies a config
        and a weight list. Both are passed through ``self._extend_config`` /
        ``self._extend_weights``, and a new model created from the resulting
        config receives the resulting weights.

        input_shape:
            Forwarded unchanged to ``super().build``.
        """
        reference_model = self._default_resnet50()

        # Architecture and weights are transformed separately and have to
        # agree with each other for ``set_weights`` to succeed.
        widened_config = self._extend_config(reference_model.get_config())
        widened_weights = self._extend_weights(reference_model.get_weights())

        backbone = Model.from_config(widened_config)
        backbone.set_weights(widened_weights)
        self.backbone = backbone

        super().build(input_shape)
def wider(conv1, conv2, old_model, add_filter):
    """Return a model in which layer ``conv1`` has additional, duplicated filters.

    ``round(add_filter * n)`` filters are drawn at random (with replacement)
    from the ``n`` existing ones, where ``n`` is the size of axis 3 of
    ``conv1``'s kernel, and appended to ``conv1``. The slices of ``conv2``'s
    kernel (axis 2) that read a duplicated channel are divided by the number
    of copies of that channel, so the copies together contribute what the
    single channel did; the appended slices additionally receive a small
    Gaussian noise.

    The layer that directly follows ``conv1`` in the model config decides
    what else is resized:

    * ``BatchNormalization``: its four weight vectors are extended with the
      same duplicated channels.
    * ``Add``: the layer two positions after ``conv1`` is treated as the
      normalisation layer, the layer directly before ``conv1`` as the
      shortcut convolution (widened exactly like ``conv1``) and the layer
      seven positions after ``conv1`` as a second consumer (rescaled exactly
      like ``conv2``). These offsets are positional assumptions about the
      architecture.

    conv1:
        Name of the layer to widen.
    conv2:
        Name of the layer that consumes ``conv1``'s output channels.
    old_model:
        Source model; it is not modified.
    add_filter: float
        Fraction of ``conv1``'s current filter count to add.

    Returns ``old_model`` itself when the rounded number of new filters is
    zero, otherwise a new model built with ``Model.from_config`` that carries
    the adapted weights and copies of all other weights.

    Raises ``ValueError`` when the layer following ``conv1`` is neither a
    ``BatchNormalization`` nor an ``Add`` layer.

    Random choices come from the global NumPy random state.
    """
    w_conv1, b_conv1 = old_model.get_layer(conv1).get_weights()
    w_conv2, b_conv2 = old_model.get_layer(conv2).get_weights()
    add_num = round(add_filter * w_conv1.shape[3])
    if add_num == 0:
        return old_model

    # Channels to duplicate; the same channel may be drawn more than once.
    index = np.random.randint(w_conv1.shape[3], size=add_num)
    # For every drawn channel: how many copies of it will exist afterwards
    # (number of times it was drawn plus the original).
    factors = np.bincount(index)[index] + 1.
    per_channel_scale = factors.reshape((1, 1, -1, 1))

    tmp_w2 = w_conv2[:, :, index, :] / per_channel_scale
    noise = np.random.normal(0, 5e-2 * tmp_w2.std(), size=tmp_w2.shape)
    new_w_conv1 = np.concatenate((w_conv1, w_conv1[:, :, :, index]), axis=3)
    new_w_conv2 = np.concatenate((w_conv2, tmp_w2 + noise), axis=2)
    # The original slices are rescaled too (without noise).
    new_w_conv2[:, :, index, :] = tmp_w2
    new_b_conv1 = np.concatenate((b_conv1, b_conv1[index]), axis=0)

    model_config = old_model.get_config()
    config_layers = model_config['layers']
    tmp_name = ''
    shotcut_name = ''
    next_conv = ''

    # Find conv1 once: inspect its neighbours, then enlarge its filter count.
    # Doing this in a single pass also covers conv1 being the first entry.
    for position, layer_cfg in enumerate(config_layers):
        if layer_cfg['config']['name'] != conv1:
            continue
        follower = config_layers[position + 1]
        if follower['class_name'] == 'BatchNormalization':
            tmp_name = follower['name']
        elif follower['class_name'] == 'Add':
            tmp_name = config_layers[position + 2]['name']
            shotcut_name = config_layers[position - 1]['name']
            next_conv = config_layers[position + 7]['name']
        layer_cfg['config']['filters'] += add_num
        break

    if not tmp_name:
        raise ValueError(
            "wider: layer '{}' must be followed by a BatchNormalization or "
            "Add layer in the model config".format(conv1))

    if shotcut_name:
        for layer_cfg in config_layers:
            if layer_cfg['config']['name'] == shotcut_name:
                layer_cfg['config']['filters'] += add_num
                break

    # The normalisation layer is expected to hold exactly four per-channel
    # vectors; each one gets the duplicated channels appended.
    a, b, c, d = old_model.get_layer(tmp_name).get_weights()
    new_norm_weights = [
        np.concatenate((vector, vector[index]), axis=0)
        for vector in (a, b, c, d)
    ]

    # Earlier entries win if two roles happen to name the same layer.
    overrides = {}
    overrides.setdefault(conv1, [new_w_conv1, new_b_conv1])
    overrides.setdefault(conv2, [new_w_conv2, b_conv2])
    overrides.setdefault(tmp_name, new_norm_weights)

    if shotcut_name:
        w_1, b_1 = old_model.get_layer(shotcut_name).get_weights()
        w_2, b_2 = old_model.get_layer(next_conv).get_weights()
        t_w2 = w_2[:, :, index, :] / per_channel_scale
        noise = np.random.normal(0, 5e-2 * t_w2.std(), size=t_w2.shape)
        new_w_1 = np.concatenate((w_1, w_1[:, :, :, index]), axis=3)
        new_w_2 = np.concatenate((w_2, t_w2 + noise), axis=2)
        new_w_2[:, :, index, :] = t_w2
        new_b_1 = np.concatenate((b_1, b_1[index]), axis=0)
        overrides.setdefault(shotcut_name, [new_w_1, new_b_1])
        overrides.setdefault(next_conv, [new_w_2, b_2])

    new_model = Model.from_config(model_config)
    for layer in new_model.layers:
        weights = overrides.get(layer.name)
        if weights is None:
            # Untouched layer: carry the old weights over unchanged.
            weights = old_model.get_layer(layer.name).get_weights()
        layer.set_weights(weights)
    return new_model
def wider_last(conv1, old_model, add_filter, dense_name='dense_1'):
    """Return a model in which ``conv1`` has extra filters and the dense layer extra inputs.

    ``round(add_filter * n)`` filters are drawn at random (with replacement)
    from the ``n`` existing ones, where ``n`` is the size of axis 3 of
    ``conv1``'s kernel, and appended to ``conv1``. The kernel of the dense
    layer ``dense_name`` is extended along axis 0 by the same number of
    all-zero rows.

    The layer that directly follows ``conv1`` in the model config decides
    what else is resized:

    * ``BatchNormalization``: its four weight vectors are extended with the
      same duplicated channels.
    * ``Add``: the layer two positions after ``conv1`` is treated as the
      normalisation layer and the layer directly before ``conv1`` as the
      shortcut convolution, which is widened exactly like ``conv1``. These
      offsets are positional assumptions about the architecture.

    conv1:
        Name of the layer to widen.
    old_model:
        Source model; it is not modified.
    add_filter: float
        Fraction of ``conv1``'s current filter count to add.
    dense_name:
        Name of the dense layer whose kernel receives the zero rows.
        Defaults to ``'dense_1'``.

    Returns ``old_model`` itself when the rounded number of new filters is
    zero, otherwise a new model built with ``Model.from_config`` that carries
    the adapted weights and copies of all other weights.

    Raises ``ValueError`` when the layer following ``conv1`` is neither a
    ``BatchNormalization`` nor an ``Add`` layer.

    Random choices come from the global NumPy random state.
    """
    w_conv1, b_conv1 = old_model.get_layer(conv1).get_weights()
    add_num = round(add_filter * w_conv1.shape[3])
    if add_num == 0:
        return old_model

    # Channels to duplicate; the same channel may be drawn more than once.
    index = np.random.randint(w_conv1.shape[3], size=add_num)
    new_w_conv1 = np.concatenate((w_conv1, w_conv1[:, :, :, index]), axis=3)
    new_b_conv1 = np.concatenate((b_conv1, b_conv1[index]), axis=0)

    model_config = old_model.get_config()
    config_layers = model_config['layers']
    tmp_name = ''
    shotcut_name = ''

    # Find conv1 once: inspect its neighbours, then enlarge its filter count.
    # Doing this in a single pass also covers conv1 being the first entry.
    for position, layer_cfg in enumerate(config_layers):
        if layer_cfg['config']['name'] != conv1:
            continue
        follower = config_layers[position + 1]
        if follower['class_name'] == 'BatchNormalization':
            tmp_name = follower['name']
        elif follower['class_name'] == 'Add':
            tmp_name = config_layers[position + 2]['name']
            shotcut_name = config_layers[position - 1]['name']
        layer_cfg['config']['filters'] += add_num
        break

    if not tmp_name:
        raise ValueError(
            "wider_last: layer '{}' must be followed by a BatchNormalization "
            "or Add layer in the model config".format(conv1))

    if shotcut_name:
        for layer_cfg in config_layers:
            if layer_cfg['config']['name'] == shotcut_name:
                layer_cfg['config']['filters'] += add_num
                break

    # The normalisation layer is expected to hold exactly four per-channel
    # vectors; each one gets the duplicated channels appended.
    a, b, c, d = old_model.get_layer(tmp_name).get_weights()
    new_norm_weights = [
        np.concatenate((vector, vector[index]), axis=0)
        for vector in (a, b, c, d)
    ]

    # Earlier entries win if two roles happen to name the same layer.
    overrides = {}
    overrides.setdefault(conv1, [new_w_conv1, new_b_conv1])
    overrides.setdefault(tmp_name, new_norm_weights)

    if shotcut_name:
        w_1, b_1 = old_model.get_layer(shotcut_name).get_weights()
        new_w_1 = np.concatenate((w_1, w_1[:, :, :, index]), axis=3)
        new_b_1 = np.concatenate((b_1, b_1[index]), axis=0)
        overrides.setdefault(shotcut_name, [new_w_1, new_b_1])

    # New dense inputs start at zero, so they add nothing to the dense output.
    # NOTE(review): assumes the dense kernel has one input row per conv1
    # channel -- confirm against the architecture.
    dense_w, dense_b = old_model.get_layer(dense_name).get_weights()
    zero_rows = np.zeros((add_num, dense_w.shape[1]))
    overrides.setdefault(
        dense_name, [np.concatenate((dense_w, zero_rows), axis=0), dense_b])

    new_model = Model.from_config(model_config)
    for layer in new_model.layers:
        weights = overrides.get(layer.name)
        if weights is None:
            # Untouched layer: carry the old weights over unchanged.
            weights = old_model.get_layer(layer.name).get_weights()
        layer.set_weights(weights)
    return new_model
예제 #4
0
def train_model_incrementally(old_model_file,
                              X,
                              Y,
                              output_file,
                              batch_size=8,
                              epochs=50,
                              λ=100000000,
                              fisher_samples=200):
    """
    Continue training a saved model with an EWC penalty and save the result.

    The saved model must already come with an estimate of the diagonal of its
    Fisher information matrix. Training minimises binary cross-entropy plus
    ``(λ / 2) * fisher * (new_weight - old_weight)**2`` summed over the stored
    weight tensors. Afterwards a new Fisher diagonal is estimated and written
    next to the trained weights.

    old_model_file:
        the path to an HDF5 file containing weights in Keras format as well as the Fisher
        information matrix (read from the ``fisher_diagonal`` group)
    X:
        inputs for training
    Y:
        outputs for training
    output_file:
        The HDF5 file for output (again containing weights and Fisher information matrix, so
        this can be used in a chain)
    batch_size: integer
        The batch size for model training -- passed directly to the model
    epochs: integer
        The number of epochs for model training -- passed directly to the model
    λ: float
        The hyperparameter for EWC. Higher values make it stay closer to old weights, lower
        values let it optimize more for the current task. The default is 10^8.
    fisher_samples: integer
        The number of samples to use for estimating the Fisher information matrix.

    Returns the trained model, compiled with a plain binary cross-entropy
    loss (the same model that is saved to ``output_file``).
    """
    # Load the old model
    old_model = keras.models.load_model(
        old_model_file, custom_objects={'dice_coef': dice_coef})

    # Load the old Fisher diagonal. The file is only read here, so open it
    # read-only instead of relying on the library's default mode.
    # NOTE(review): keys are sorted lexicographically; this matches the weight
    # order only if write_fisher names the datasets accordingly -- confirm.
    with h5py.File(old_model_file, 'r') as f:
        fisher_group = f['fisher_diagonal']
        fisher = [
            np.array(fisher_group[key]) for key in sorted(fisher_group.keys())
        ]

    # Snapshot the reference weights once. They are taken from
    # `trainable_weights` so that index i refers to the same tensor as
    # `model.trainable_weights[i]` in the loss below; `get_weights()` would
    # also include non-trainable weights and shift the indices.
    old_weights = K.batch_get_value(old_model.trainable_weights)

    # Create the new loss function
    def ewc_loss(model):
        def loss(y_true, y_pred):
            standard_loss = binary_crossentropy(y_true, y_pred)
            ewc_term = 0
            for position, fisher_diagonal in enumerate(fisher):
                drift = model.trainable_weights[position] - old_weights[position]
                ewc_term += K.sum((λ / 2) * fisher_diagonal * drift**2)
            return standard_loss + ewc_term

        return loss

    # Create the new model from that loss function
    new_model_ewc = Model.from_config(old_model.get_config())
    new_model_ewc.set_weights(deepcopy(old_model.get_weights()))
    new_model_ewc.compile(optimizer='adam',
                          loss=ewc_loss(new_model_ewc),
                          metrics=[dice_coef])
    # Train that model
    new_model_ewc.fit(X, Y, batch_size=batch_size, epochs=epochs)

    # Create a version of that model with a binary cross-entropy loss
    new_model_bce = Model.from_config(new_model_ewc.get_config())
    new_model_bce.set_weights(new_model_ewc.get_weights())
    new_model_bce.compile(optimizer='adam',
                          loss='binary_crossentropy',
                          metrics=[dice_coef])
    # Compute the new Fisher diagonal from that
    new_fisher = estimate_fisher_information(new_model_bce, X, Y,
                                             fisher_samples)
    # Save the model with binary cross-entropy loss (so no loading issues)
    new_model_bce.save(output_file)
    # Save the Fisher information matrix in the same file
    write_fisher(output_file, new_fisher)
    return new_model_bce