예제 #1
0
def test_DIN_att():
    """Smoke-test DIN with ``use_din=True``: fit once, then save/reload weights.

    Builds the model from the values returned by ``get_xy_fd()``, compiles
    and fits it, and round-trips the weights through the relative path
    ``DIN_att_weights.h5``. Progress is reported with ``print``.
    """
    model_name = "DIN_att"
    weights_file = model_name + '_weights.h5'

    inputs, labels, feature_dims, behavior_features = get_xy_fd()

    # Model hyper-parameters gathered in one place.
    din_options = dict(
        hist_len_max=4,
        embedding_size=8,
        use_din=True,
        hidden_size=[4, 4, 4],
        keep_prob=0.6,
    )
    model = DIN(feature_dims, behavior_features, **din_options)

    model.compile('adam', 'binary_crossentropy', metrics=['binary_crossentropy'])
    model.fit(inputs, labels, verbose=1, validation_split=0.5)
    print("{} test train valid pass!".format(model_name))

    # Weights round trip: write the file, then read it straight back.
    model.save_weights(weights_file)
    model.load_weights(weights_file)
    print("{} test save load weight pass!".format(model_name))

    # NOTE: the whole-model save_model/load_model round trip is not exercised
    # here; the previously disabled attempt carried the message
    # "There is a bug when save model use Dice".

    print("{} test pass!".format(model_name))
예제 #2
0
파일: DIN_test.py 프로젝트: nwf5d/DeepCTR
def test_DIN_sum():
    """Smoke-test DIN with ``use_din=False``: fit, then save/reload weights and model.

    Builds the model from the values returned by ``get_xy_fd()``, compiles
    and fits it, round-trips the weights through ``DIN_sum_weights.h5`` and
    then the whole model through ``DIN_sum.h5`` (reloaded with
    ``custom_objects``). Progress is reported with ``print``.
    """
    model_name = "DIN_sum"
    weights_file = model_name + '_weights.h5'
    model_file = model_name + '.h5'

    inputs, labels, feature_dims, behavior_features = get_xy_fd()

    # Model hyper-parameters gathered in one place.
    din_options = dict(
        hist_len_max=4,
        embedding_size=8,
        use_din=False,
        hidden_size=[4, 4, 4],
        keep_prob=0.6,
        activation="sigmoid",
    )
    model = DIN(feature_dims, behavior_features, **din_options)

    model.compile('adam', 'binary_crossentropy', metrics=['binary_crossentropy'])
    model.fit(inputs, labels, verbose=1, validation_split=0.5)
    print("{} test train valid pass!".format(model_name))

    # Weights round trip: write the file, then read it straight back.
    model.save_weights(weights_file)
    model.load_weights(weights_file)
    print("{} test save load weight pass!".format(model_name))

    # Whole-model round trip; loading needs the custom_objects mapping.
    save_model(model, model_file)
    model = load_model(model_file, custom_objects)
    print("{} test save load model pass!".format(model_name))

    print("{} test pass!".format(model_name))
예제 #3
0
def test_DIN_att():
    """Fit a DIN model built with ``use_din=True`` and reload its saved weights.

    The inputs come from ``get_xy_fd()``. After compiling and fitting, the
    weights are written to and read back from the relative path
    ``DIN_att_weights.h5``. Each completed stage is announced with ``print``.
    """
    model_name = "DIN_att"
    weights_path = '%s_weights.h5' % model_name

    data, target, dim_dict, behavior_list = get_xy_fd()

    model = DIN(
        dim_dict,
        behavior_list,
        hist_len_max=4,
        embedding_size=8,
        use_din=True,
        hidden_size=[4, 4, 4],
        keep_prob=0.6,
    )

    metric_names = ['binary_crossentropy']
    model.compile('adam', 'binary_crossentropy', metrics=metric_names)
    model.fit(data, target, verbose=1, validation_split=0.5)
    print("%s test train valid pass!" % model_name)

    # Save the weights and immediately load them back from the same file.
    model.save_weights(weights_path)
    model.load_weights(weights_path)
    print("%s test save load weight pass!" % model_name)

    # NOTE: saving/loading the whole model is not exercised here; the
    # previously disabled attempt carried the message
    # "There is a bug when save model use Dice".

    print("%s test pass!" % model_name)
예제 #4
0
                att_activation='dice',
                att_weight_normalization=False,
                hist_len_max=sess_len_max,
                dnn_hidden_units=(200, 80),
                att_hidden_size=(
                    64,
                    16,
                ))
    model.compile('adagrad',
                  'binary_crossentropy',
                  metrics=['binary_crossentropy'])
    model_dir = "../model_dir_" + str(EMBEDDING_SIZE)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    if os.path.exists(model_dir + '/ckpt.h5'):
        model.load_weights(model_dir + '/ckpt.h5')
    """
    test_input_pos = pd.read_pickle(
          '../model_input/din_input_'+test_date+'.pkl')
    test_input_neg = pd.read_pickle(
          '../model_input/din_input_'+test_date+'_neg.pkl')
    test_input = []
    for i in range(len(test_input_pos)):
      feature_input = np.concatenate([test_input_pos[i],test_input_neg[i]],axis=0)
      test_input.append(feature_input)
      # model_input = np.concatenate([model_input_pos,model_input_neg],axis=1)
    test_label_pos = pd.read_pickle('../model_input/din_label_'+test_date+'.pkl')
    test_label_neg = pd.read_pickle('../model_input/din_label_'+test_date+'_neg.pkl')
    test_label = np.concatenate([test_label_pos,test_label_neg],axis=-1)

    del test_input_pos