Code Example #1
File: train.py  Project: langlrsw/MEED
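# NOTE: this excerpt assumes the usual imports at the top of the project's training
# script. A likely minimal set (TensorFlow 1.x style, since tf.placeholder is used) is:
#   import os, time
#   import numpy as np
#   import tensorflow as tf
#   from keras.utils import to_categorical
#   from keras.models import load_model
# plus the project's own helpers (load_data, create_original_model, gradient_explain,
# build_model, reader_vector_simple_rand, gen_vectors_max_inter_angle, softmax,
# eval_without_approximator, get_approximator_eval, eval_with_approximator,
# eval_sensitivity2) and the globals k and session defined elsewhere in the project.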
def MEED(train=True):
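    # Train the MEED explainer (positive branch model_pos and negative branch model_neg)
    # against a pretrained classifier M. M's predictions on the training set serve as soft
    # targets, a gradient-based saliency map provides a selection prior, and fidelity (FS),
    # infidelity (FU) and sensitivity (SEN) are reported at the steps listed in step_list.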

    flag_train_app = True
    flag_with_y = True

    print('Loading dataset...')
    x_train, y_train, x_val, y_val, mx = load_data()
    mx = mx * 0.

    input_shape = x_train[0].shape
    num_classes = y_train.shape[1]

    with session as sess:

        M = create_original_model(input_shape, num_classes)
        weights_name = [
            i for i in os.listdir('./models') if i.startswith('original')
        ][0]
        M.load_weights('./models/' + weights_name, by_name=True)

        pred_train = M.predict(x_train, batch_size=32)
        pred_val = M.predict(x_val, batch_size=32)

        x_placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, input_shape[0], input_shape[1], input_shape[2]])
        y_placeholder = tf.placeholder(dtype=tf.float32,
                                       shape=[None, num_classes])

        # Explanation model
        # Sum of gradients (saliency) w.r.t. the input
        grad_sum = gradient_explain(M, x_placeholder, y_placeholder)

        print('Creating model...')
        model_pos, model_neg, _, As, Au, E = build_model(
            input_shape, num_classes, k, flag_with_y)

        train_acc = np.mean(
            np.argmax(pred_train, axis=1) == np.argmax(y_train, axis=1))
        val_acc = np.mean(
            np.argmax(pred_val, axis=1) == np.argmax(y_val, axis=1))
        print(
            'The train and validation accuracy of the original model are {} and {}'
            .format(train_acc, val_acc))

        if train:
            epochs = 50
            batch_size = 32
            data_reader_train = reader_vector_simple_rand.Reader(
                x_train,
                pred_train,
                batch_size=batch_size,
                flag_shuffle=True,
                rng_seed=123)

            n_step = int(x_train.shape[0] * epochs / batch_size)
            step_list_sub = np.array([1, 2, 5]).astype(int) * 10
            step_list = []
            for i_ratio in range(20):
                step_list.extend(step_list_sub)
                step_list_sub = step_list_sub * 10
            step_list.append(n_step - 1)

            i_step = -1
            while i_step < n_step:
                i_step += 1

                if True:

                    x_batch, y_batch = data_reader_train.iterate_batch()
                    if x_batch.shape[0] != batch_size:
                        continue

                    if i_step % 2 == 0:
                        flag_switch = True
                    else:
                        flag_switch = False

                    y_batch_one_hot = to_categorical(
                        np.argmax(y_batch, axis=1), num_classes)

                    selection_prior = sess.run(grad_sum,
                                               feed_dict={
                                                   x_placeholder: x_batch,
                                                   y_placeholder:
                                                   y_batch_one_hot
                                               })
                    selection_prior = np.reshape(selection_prior,
                                                 [x_batch.shape[0], -1])
                    selection_prior = softmax(-selection_prior)
                    selection_prior = np.log(selection_prior + 1e-40)
                    selection_prior = np.reshape(selection_prior, [
                        x_batch.shape[0], x_batch.shape[1] // 4,
                        x_batch.shape[2] // 4, 1
                    ])
                    # Fixed value; alternatives left commented out: float(1e5), i_step,
                    # i_step*x_batch.shape[0]/float(x_train.shape[0]), float(1e8)
                    epoch_cur = 1.
                    epoch_cur = epoch_cur * np.ones(
                        (x_batch.shape[0], 1, 1, 1))

                    mean_x = np.repeat(mx, x_batch.shape[0], axis=0)

                    flag_random = False
                    batch_size_cur = x_batch.shape[0]
                    n_chunk = 8
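                    # Train the positive explainer, then the negative one, on the same batch.
                    # Targets concatenate the model's predictions with n_chunk reference
                    # vectors from gen_vectors_max_inter_angle.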
                    for _ in range(1):
                        zzz = y_batch
                        Theta = gen_vectors_max_inter_angle(
                            num_classes,
                            1,
                            M=batch_size_cur * n_chunk,
                            flag_random=flag_random)
                        for i_chunk in range(n_chunk):
                            zzz = np.hstack([
                                zzz,
                                Theta[i_chunk * batch_size_cur:(i_chunk + 1) *
                                      batch_size_cur, :]
                            ])
                        if flag_with_y:
                            model_pos.train_on_batch([
                                x_batch, y_batch, mean_x, selection_prior,
                                epoch_cur
                            ], [y_batch, zzz])
                        else:
                            model_pos.train_on_batch(
                                [x_batch, mean_x, selection_prior, epoch_cur],
                                [y_batch, zzz])
                    for _ in range(1):
                        zzz = y_batch
                        Theta = gen_vectors_max_inter_angle(
                            num_classes,
                            1,
                            M=batch_size_cur * n_chunk,
                            flag_random=flag_random)
                        for i_chunk in range(n_chunk):
                            zzz = np.hstack([
                                zzz,
                                Theta[i_chunk * batch_size_cur:(i_chunk + 1) *
                                      batch_size_cur, :]
                            ])
                        if flag_with_y:
                            model_neg.train_on_batch([
                                x_batch, y_batch, mean_x, selection_prior,
                                epoch_cur
                            ], [y_batch, zzz, epoch_cur])
                        else:
                            model_neg.train_on_batch(
                                [x_batch, mean_x, selection_prior, epoch_cur],
                                [y_batch, zzz, epoch_cur])
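
                # Periodic evaluation: inference time (TPS), fidelity/infidelity with and
                # without approximators, and sensitivity.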

                if i_step in step_list:
                    print(
                        '------------------------ test ----------------------------'
                    )

                    st = time.time()
                    pred_val = M.predict(x_val, batch_size=2048)
                    duration = time.time() - st
                    print('TPS = {}'.format(duration / x_val.shape[0]))

                    fidelity, infidelity = eval_without_approximator(
                        E,
                        M,
                        x_val,
                        pred_val,
                        flag_with_y=flag_with_y,
                        k=k,
                        mx=mx)
                    print('step: %d\tFS-M=%.4f\tFU-M=%.4f' %
                          (i_step, fidelity, infidelity))

                    if flag_train_app:
                        model_As = get_approximator_eval(
                            input_shape, num_classes)
                        model_Au = get_approximator_eval(
                            input_shape, num_classes)
                        fidelity, infidelity, _, _ = eval_with_approximator(
                            E,
                            model_As,
                            model_Au,
                            x_train,
                            pred_train,
                            x_val,
                            pred_val,
                            epochs=12,
                            flag_with_y=flag_with_y,
                            k=k)
                        print('step: %d\tFS-A=%.4f\tFU-A=%.4f' %
                              (i_step, fidelity, infidelity))

                    sen_N = 50
                    epsilon = 0.2
                    if not flag_with_y:
                        pred_val = None
                    sensitivity = eval_sensitivity2(x_val,
                                                    E,
                                                    epsilon,
                                                    sen_N,
                                                    y=pred_val,
                                                    selection=None)
                    print('step: %d, SEN=%g' % (i_step, sensitivity))

    return
Code Example #2
def MEED(train=True):
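    # Same MEED training loop for a text model: inputs are integer sequences of length
    # maxlen over a max_features vocabulary. The pretrained model is loaded from models/,
    # the explainers model_pos/model_neg are trained on its predictions, and fidelity,
    # infidelity and timing are evaluated at the steps in step_list.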

    epochs = 20
    batch_size = 32
    flag_train_app = True

    print('Loading dataset...')
    x_train, y_train, x_val, y_val, id_to_word = load_data()

    with session as sess:
        original_model, model_predict = get_original_model()
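        # Training of the original model is switched off here (if False:); pretrained
        # weights are loaded from models/ instead.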
        if False:
            original_model.fit(x_train,
                               y_train,
                               validation_data=(x_val, y_val),
                               epochs=epochs,
                               batch_size=batch_size)
            original_model.save('models/given_model_train.h5')
            model_predict.save('models/given_model_pred.h5')
        else:

            original_model = load_model('models/given_model_train.h5')
            model_predict = load_model('models/given_model_pred.h5')

        M = original_model

        pred_train = M.predict(x_train, batch_size=256)
        pred_val = M.predict(x_val, batch_size=256)

        x_placeholder = tf.placeholder(dtype=tf.int32, shape=[None, maxlen])
        y_placeholder = tf.placeholder(dtype=tf.float32, shape=[None, 2])
        grad_sum = gradient_explain(model_predict, x_placeholder,
                                    y_placeholder)

        print('Creating model...')
        model_pos, model_neg, _, As, Au, E = build_model(
            maxlen, max_features, embedding_dims, k)

        train_acc = np.mean(
            np.argmax(pred_train, axis=1) == np.argmax(y_train, axis=1))
        val_acc = np.mean(
            np.argmax(pred_val, axis=1) == np.argmax(y_val, axis=1))
        print(
            'The train and validation accuracy of the original model are {} and {}'
            .format(train_acc, val_acc))

        if train:

            data_reader_train = reader_vector_simple_rand.Reader(
                x_train,
                pred_train,
                batch_size=batch_size,
                flag_shuffle=True,
                rng_seed=123)

            n_step = int(x_train.shape[0] * epochs / batch_size)
            step_list_sub = np.array([1, 2, 5]).astype(int) * 100
            step_list = []
            for i_ratio in range(20):
                step_list.extend(step_list_sub)
                step_list_sub = step_list_sub * 10
            step_list.append(n_step - 1)

            i_step = -1
            while i_step < n_step:
                i_step += 1

                if True:

                    x_batch, y_batch = data_reader_train.iterate_batch()
                    if x_batch.shape[0] != batch_size:
                        continue

                    y_batch = M.predict(x_batch)
                    y_batch_one_hot = to_categorical(
                        np.argmax(y_batch, axis=1), 2)

                    selection_prior = sess.run(grad_sum,
                                               feed_dict={
                                                   x_placeholder: x_batch,
                                                   y_placeholder:
                                                   y_batch_one_hot
                                               })

                    selection_prior = softmax(-selection_prior)
                    selection_prior = np.log(selection_prior + 1e-40)
                    selection_prior = selection_prior[:, :, np.newaxis]
                    epoch_cur = np.ceil(i_step * batch_size /
                                        float(x_train.shape[0]))
                    epoch_cur = epoch_cur * np.ones((x_batch.shape[0], 1, 1))
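                    # The positive explainer is trained to reproduce the model's prediction;
                    # the negative explainer targets its complement (1. - y_batch).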

                    model_pos.train_on_batch(
                        [x_batch, y_batch, selection_prior, epoch_cur],
                        [y_batch, y_batch])
                    model_neg.train_on_batch(
                        [x_batch, y_batch, selection_prior, epoch_cur],
                        [y_batch, 1. - y_batch, epoch_cur])
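
                # Periodic evaluation: inference time (TPS) and fidelity/infidelity, with
                # and without approximators.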

                if i_step in step_list:

                    print(
                        '------------------------ test ----------------------------'
                    )
                    pred_train = M.predict(x_train, batch_size=2048)

                    st = time.time()
                    pred_val = M.predict(x_val, batch_size=2048)
                    duration = time.time() - st
                    print('TPS = {}'.format(duration / x_val.shape[0]))

                    fidelity, infidelity = eval_without_approximator(
                        E,
                        M,
                        x_val,
                        pred_val,
                        flag_with_y=True,
                        k=k,
                        id_to_word=id_to_word)
                    print('step: %d\tFS-M=%.4f\tFU-M=%.4f' %
                          (i_step, fidelity, infidelity))

                    if flag_train_app:

                        model_As = get_approximator_eval(
                            maxlen, max_features, embedding_dims)
                        model_Au = get_approximator_eval(
                            maxlen, max_features, embedding_dims)

                        fidelity, infidelity, model_As, model_Au = eval_with_approximator(
                            E,
                            model_As,
                            model_Au,
                            x_train,
                            pred_train,
                            x_val,
                            pred_val,
                            epochs=5,
                            flag_with_y=True,
                            k=k)
                        print('step: %d\tFS-A=%.4f\tFU-A=%.4f' %
                              (i_step, fidelity, infidelity))

    return